396 lines
14 KiB
Go
396 lines
14 KiB
Go
package plant
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
qdrant "github.com/qdrant/go-client/qdrant"
|
|
openai "github.com/sashabaranov/go-openai"
|
|
"go.uber.org/zap"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/metadata"
|
|
"sundynix-go/global"
|
|
plantModel "sundynix-go/model/plant"
|
|
"sundynix-go/model/system"
|
|
systemService "sundynix-go/service/system"
|
|
)
|
|
|
|
// AiRagService groups the plant-wiki RAG operations: syncing wiki entries into a
// Qdrant vector collection, removing them, and answering chat questions with
// context retrieved from that collection.
type AiRagService struct{}
|
|
|
|
// sysAiConfigService is used to fetch the active AI configuration (API keys,
// endpoints, model names, Qdrant settings) at the start of each operation.
var sysAiConfigService = systemService.SysAiConfigService{}
|
|
|
|
// ──────────────────────────────────────────────────────────────
|
|
// OpenAI 客户端构建
|
|
// ──────────────────────────────────────────────────────────────
|
|
|
|
func getChatClient(cfg *system.SysAiConfig) *openai.Client {
|
|
config := openai.DefaultConfig(cfg.ChatApiKey)
|
|
if cfg.ChatApiUrl != "" {
|
|
config.BaseURL = cfg.ChatApiUrl
|
|
}
|
|
return openai.NewClientWithConfig(config)
|
|
}
|
|
|
|
func getEmbeddingClient(cfg *system.SysAiConfig) *openai.Client {
|
|
config := openai.DefaultConfig(cfg.EmbeddingApiKey)
|
|
if cfg.EmbeddingApiUrl != "" {
|
|
config.BaseURL = cfg.EmbeddingApiUrl
|
|
}
|
|
return openai.NewClientWithConfig(config)
|
|
}
|
|
|
|
// ──────────────────────────────────────────────────────────────
|
|
// Qdrant gRPC 连接
|
|
// ──────────────────────────────────────────────────────────────
|
|
|
|
func newQdrantConn(cfg *system.SysAiConfig) (*grpc.ClientConn, context.Context, error) {
|
|
addr := strings.TrimPrefix(cfg.QdrantUrl, "http://")
|
|
addr = strings.TrimPrefix(addr, "https://")
|
|
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err)
|
|
}
|
|
ctx := context.Background()
|
|
if cfg.QdrantApiKey != "" {
|
|
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("api-key", cfg.QdrantApiKey))
|
|
}
|
|
return conn, ctx, nil
|
|
}
|
|
|
|
// EnsureCollection 确保 Collection 存在,不存在则创建
|
|
func EnsureCollection(cfg *system.SysAiConfig) error {
|
|
conn, ctx, err := newQdrantConn(cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
dim := uint64(cfg.VectorDimension)
|
|
if dim == 0 {
|
|
dim = 104
|
|
}
|
|
collClient := qdrant.NewCollectionsClient(conn)
|
|
if _, getErr := collClient.Get(ctx, &qdrant.GetCollectionInfoRequest{CollectionName: cfg.QdrantCollection}); getErr == nil {
|
|
return nil // 已存在
|
|
}
|
|
_, err = collClient.Create(ctx, &qdrant.CreateCollection{
|
|
CollectionName: cfg.QdrantCollection,
|
|
VectorsConfig: &qdrant.VectorsConfig{
|
|
Config: &qdrant.VectorsConfig_Params{
|
|
Params: &qdrant.VectorParams{Size: dim, Distance: qdrant.Distance_Cosine},
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("qdrant create collection failed: %w", err)
|
|
}
|
|
global.Logger.Info("Qdrant collection created", zap.String("collection", cfg.QdrantCollection))
|
|
return nil
|
|
}
|
|
|
|
// wikiID → Qdrant point UUID(确保幂等)
|
|
func wikiToQdrantID(wikiId string) string {
|
|
return uuid.NewMD5(uuid.NameSpaceOID, []byte(wikiId)).String()
|
|
}
|
|
|
|
// buildWikiText 拼接用于向量化的文本语料
|
|
func buildWikiText(w plantModel.Wiki) string {
|
|
return fmt.Sprintf(
|
|
"植物名字:%s. 拉丁名:%s. 科属:%s. 生命周期:%s. 生长习性:%s. 病虫害:%s. 光照类型:%s. 最佳温度:%s.",
|
|
w.Name, w.LatinName, w.Genus, w.LifeCycle, w.GrowthHabit,
|
|
w.PestsDiseases, w.LightType, w.OptimalTempPeriod,
|
|
)
|
|
}
|
|
|
|
// ──────────────────────────────────────────────────────────────
|
|
// SyncSingleWikiAsync 异步同步单条百科到 Qdrant(新增/更新时调用)
|
|
// 同步成功后将 is_vector_synced 置为 1
|
|
// ──────────────────────────────────────────────────────────────
|
|
func (s *AiRagService) SyncSingleWikiAsync(wikiId string) {
|
|
go func() {
|
|
if err := s.syncSingleWiki(wikiId); err != nil {
|
|
global.Logger.Error("Async sync wiki to Qdrant failed", zap.String("wiki_id", wikiId), zap.Error(err))
|
|
}
|
|
}()
|
|
}
|
|
|
|
// SyncSingleWiki 同步同步单条百科到 Qdrant(用于API直接调用)
|
|
func (s *AiRagService) SyncSingleWiki(wikiId string) error {
|
|
return s.syncSingleWiki(wikiId)
|
|
}
|
|
|
|
func (s *AiRagService) syncSingleWiki(wikiId string) error {
|
|
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
|
if err != nil {
|
|
return fmt.Errorf("no active ai config: %w", err)
|
|
}
|
|
if err = EnsureCollection(cfg); err != nil {
|
|
global.Logger.Warn("EnsureCollection warn", zap.Error(err))
|
|
}
|
|
|
|
var w plantModel.Wiki
|
|
if err = global.DB.Where("id = ?", wikiId).First(&w).Error; err != nil {
|
|
return fmt.Errorf("wiki not found: %w", err)
|
|
}
|
|
|
|
text := buildWikiText(w)
|
|
embClient := getEmbeddingClient(cfg)
|
|
embResp, err := embClient.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
|
|
Input: []string{text},
|
|
Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("embedding failed: %w", err)
|
|
}
|
|
|
|
conn, qdCtx, err := newQdrantConn(cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
ptsClient := qdrant.NewPointsClient(conn)
|
|
|
|
_, err = ptsClient.Upsert(qdCtx, &qdrant.UpsertPoints{
|
|
CollectionName: cfg.QdrantCollection,
|
|
Points: []*qdrant.PointStruct{{
|
|
Id: qdrant.NewID(wikiToQdrantID(wikiId)),
|
|
Vectors: qdrant.NewVectors(embResp.Data[0].Embedding...),
|
|
Payload: map[string]*qdrant.Value{
|
|
"wiki_id": qdrant.NewValueString(w.Id),
|
|
"name": qdrant.NewValueString(w.Name),
|
|
"full_text": qdrant.NewValueString(text),
|
|
},
|
|
}},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("qdrant upsert failed: %w", err)
|
|
}
|
|
|
|
// 更新同步状态
|
|
_ = global.DB.Model(&plantModel.Wiki{}).Where("id = ?", wikiId).Update("is_vector_synced", 1).Error
|
|
global.Logger.Info("Wiki synced to Qdrant", zap.String("wiki_id", wikiId))
|
|
return nil
|
|
}
|
|
|
|
// ──────────────────────────────────────────────────────────────
|
|
// DeleteFromQdrant 从 Qdrant 删除单条植物的向量点位
|
|
// 删除成功后将 is_vector_synced 置为 0
|
|
// ──────────────────────────────────────────────────────────────
|
|
func (s *AiRagService) DeleteFromQdrant(wikiId string) error {
|
|
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
|
if err != nil {
|
|
return fmt.Errorf("no active ai config: %w", err)
|
|
}
|
|
conn, qdCtx, err := newQdrantConn(cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
ptsClient := qdrant.NewPointsClient(conn)
|
|
|
|
qID := wikiToQdrantID(wikiId)
|
|
_, err = ptsClient.Delete(qdCtx, &qdrant.DeletePoints{
|
|
CollectionName: cfg.QdrantCollection,
|
|
Points: &qdrant.PointsSelector{
|
|
PointsSelectorOneOf: &qdrant.PointsSelector_Points{
|
|
Points: &qdrant.PointsIdsList{
|
|
Ids: []*qdrant.PointId{qdrant.NewID(qID)},
|
|
},
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("qdrant delete failed: %w", err)
|
|
}
|
|
_ = global.DB.Model(&plantModel.Wiki{}).Where("id = ?", wikiId).Update("is_vector_synced", 0).Error
|
|
global.Logger.Info("Wiki deleted from Qdrant", zap.String("wiki_id", wikiId))
|
|
return nil
|
|
}
|
|
|
|
// DeleteFromQdrantBatch 批量从 Qdrant 删除(用于批量删除百科时)
|
|
func (s *AiRagService) DeleteFromQdrantBatch(wikiIds []string) {
|
|
go func() {
|
|
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
|
if err != nil {
|
|
global.Logger.Warn("No active ai config for batch qdrant delete", zap.Error(err))
|
|
return
|
|
}
|
|
conn, qdCtx, err := newQdrantConn(cfg)
|
|
if err != nil {
|
|
global.Logger.Warn("Qdrant connect failed for batch delete", zap.Error(err))
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
ptsClient := qdrant.NewPointsClient(conn)
|
|
|
|
var ids []*qdrant.PointId
|
|
for _, wid := range wikiIds {
|
|
ids = append(ids, qdrant.NewID(wikiToQdrantID(wid)))
|
|
}
|
|
_, err = ptsClient.Delete(qdCtx, &qdrant.DeletePoints{
|
|
CollectionName: cfg.QdrantCollection,
|
|
Points: &qdrant.PointsSelector{
|
|
PointsSelectorOneOf: &qdrant.PointsSelector_Points{
|
|
Points: &qdrant.PointsIdsList{Ids: ids},
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
global.Logger.Error("Qdrant batch delete failed", zap.Error(err))
|
|
} else {
|
|
global.Logger.Info("Qdrant batch delete done", zap.Int("count", len(wikiIds)))
|
|
}
|
|
}()
|
|
}
|
|
|
|
// ──────────────────────────────────────────────────────────────
|
|
// SyncWikiToQdrant 全量同步(后台操作/手动触发)
|
|
// ──────────────────────────────────────────────────────────────
|
|
func (s *AiRagService) SyncWikiToQdrant() error {
|
|
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err = EnsureCollection(cfg); err != nil {
|
|
global.Logger.Warn("EnsureCollection failed, continuing", zap.Error(err))
|
|
}
|
|
|
|
var wikis []plantModel.Wiki
|
|
if err = global.DB.Find(&wikis).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
embClient := getEmbeddingClient(cfg)
|
|
conn, qdCtx, err := newQdrantConn(cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
ptsClient := qdrant.NewPointsClient(conn)
|
|
|
|
var successIds []string
|
|
for _, w := range wikis {
|
|
text := buildWikiText(w)
|
|
embResp, embErr := embClient.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
|
|
Input: []string{text},
|
|
Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
|
|
})
|
|
if embErr != nil {
|
|
global.Logger.Error("Embedding failed", zap.String("wiki_id", w.Id), zap.Error(embErr))
|
|
continue
|
|
}
|
|
_, upsertErr := ptsClient.Upsert(qdCtx, &qdrant.UpsertPoints{
|
|
CollectionName: cfg.QdrantCollection,
|
|
Points: []*qdrant.PointStruct{{
|
|
Id: qdrant.NewID(wikiToQdrantID(w.Id)),
|
|
Vectors: qdrant.NewVectors(embResp.Data[0].Embedding...),
|
|
Payload: map[string]*qdrant.Value{
|
|
"wiki_id": qdrant.NewValueString(w.Id),
|
|
"name": qdrant.NewValueString(w.Name),
|
|
"full_text": qdrant.NewValueString(text),
|
|
},
|
|
}},
|
|
})
|
|
if upsertErr != nil {
|
|
global.Logger.Error("Qdrant upsert failed", zap.String("wiki_id", w.Id), zap.Error(upsertErr))
|
|
} else {
|
|
successIds = append(successIds, w.Id)
|
|
}
|
|
}
|
|
|
|
// 批量更新同步状态
|
|
if len(successIds) > 0 {
|
|
_ = global.DB.Model(&plantModel.Wiki{}).Where("id IN ?", successIds).Update("is_vector_synced", 1).Error
|
|
}
|
|
global.Logger.Info("SyncWikiToQdrant done", zap.Int("total", len(wikis)), zap.Int("success", len(successIds)))
|
|
return nil
|
|
}
|
|
|
|
// ──────────────────────────────────────────────────────────────
|
|
// PlantChatStreamRAG 向量检索 + 大模型流式对话
|
|
// ──────────────────────────────────────────────────────────────
|
|
func (s *AiRagService) PlantChatStreamRAG(ctx context.Context, userQuery string, onData func(chunk string) error) error {
|
|
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
embClient := getEmbeddingClient(cfg)
|
|
chatClient := getChatClient(cfg)
|
|
|
|
embResp, err := embClient.CreateEmbeddings(ctx, openai.EmbeddingRequest{
|
|
Input: []string{userQuery},
|
|
Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("向量化查询失败: %w", err)
|
|
}
|
|
|
|
conn, qdCtx, connErr := newQdrantConn(cfg)
|
|
var contextText string
|
|
if connErr == nil {
|
|
defer conn.Close()
|
|
ptsClient := qdrant.NewPointsClient(conn)
|
|
limit := uint64(3)
|
|
searchRes, searchErr := ptsClient.Search(qdCtx, &qdrant.SearchPoints{
|
|
CollectionName: cfg.QdrantCollection,
|
|
Vector: embResp.Data[0].Embedding,
|
|
Limit: limit,
|
|
WithPayload: &qdrant.WithPayloadSelector{
|
|
SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true},
|
|
},
|
|
})
|
|
if searchErr != nil {
|
|
global.Logger.Warn("Qdrant search failed, using empty context", zap.Error(searchErr))
|
|
} else {
|
|
for _, pt := range searchRes.GetResult() {
|
|
if txt, ok := pt.GetPayload()["full_text"]; ok {
|
|
contextText += txt.GetStringValue() + "\n"
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
global.Logger.Warn("Qdrant connect failed, skipping RAG", zap.Error(connErr))
|
|
}
|
|
|
|
systemPrompt := "你是一个专业的植物百科助手,请基于以下知识库信息回答用户问题。如果知识库无相关信息,结合你的通用知识作答。\n"
|
|
if contextText != "" {
|
|
systemPrompt += "--- 知识库 ---\n" + contextText + "\n--------------\n"
|
|
}
|
|
|
|
stream, err := chatClient.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
|
Model: cfg.ChatModelName,
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
|
|
{Role: openai.ChatMessageRoleUser, Content: userQuery},
|
|
},
|
|
Stream: true,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("大模型调用失败: %w", err)
|
|
}
|
|
defer stream.Close()
|
|
|
|
for {
|
|
resp, recvErr := stream.Recv()
|
|
if errors.Is(recvErr, io.EOF) {
|
|
break
|
|
}
|
|
if recvErr != nil {
|
|
return recvErr
|
|
}
|
|
if len(resp.Choices) > 0 {
|
|
if content := resp.Choices[0].Delta.Content; content != "" {
|
|
if writeErr := onData(content); writeErr != nil {
|
|
return writeErr
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|