package plant

import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"

	"github.com/google/uuid"
	qdrant "github.com/qdrant/go-client/qdrant"
	openai "github.com/sashabaranov/go-openai"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"

	"sundynix-go/global"
	plantModel "sundynix-go/model/plant"
	"sundynix-go/model/system"
	systemService "sundynix-go/service/system"
)

const (
	// ragDefaultVectorDimension is the vector size used when the AI config
	// leaves VectorDimension at zero.
	// NOTE(review): the previous fallback was 104, which matches no common
	// embedding size and looks like a typo for 1024 — confirm against the
	// embedding model actually deployed.
	ragDefaultVectorDimension = 1024

	// ragSearchLimit is the number of nearest wiki entries fed to the chat
	// model as knowledge-base context.
	ragSearchLimit = 3
)

// errEmptyEmbedding is returned when the embedding endpoint answers without
// any usable vector.
var errEmptyEmbedding = errors.New("embedding response contains no vectors")

// AiRagService syncs plant wiki entries into Qdrant and answers questions
// with retrieval-augmented chat completions.
type AiRagService struct{}

var sysAiConfigService = systemService.SysAiConfigService{}

// ──────────────────────────────────────────────────────────────
// OpenAI client construction
// ──────────────────────────────────────────────────────────────

// getChatClient builds an OpenAI-compatible client for chat completions,
// overriding the base URL when the config provides one.
func getChatClient(cfg *system.SysAiConfig) *openai.Client {
	config := openai.DefaultConfig(cfg.ChatApiKey)
	if cfg.ChatApiUrl != "" {
		config.BaseURL = cfg.ChatApiUrl
	}
	return openai.NewClientWithConfig(config)
}

// getEmbeddingClient builds an OpenAI-compatible client for embeddings,
// overriding the base URL when the config provides one.
func getEmbeddingClient(cfg *system.SysAiConfig) *openai.Client {
	config := openai.DefaultConfig(cfg.EmbeddingApiKey)
	if cfg.EmbeddingApiUrl != "" {
		config.BaseURL = cfg.EmbeddingApiUrl
	}
	return openai.NewClientWithConfig(config)
}

// embedText requests a single embedding for text and returns its vector.
// It returns errEmptyEmbedding instead of panicking when the provider
// replies without data.
func embedText(ctx context.Context, client *openai.Client, cfg *system.SysAiConfig, text string) ([]float32, error) {
	resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
		Input: []string{text},
		Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
	})
	if err != nil {
		return nil, err
	}
	if len(resp.Data) == 0 || len(resp.Data[0].Embedding) == 0 {
		return nil, errEmptyEmbedding
	}
	return resp.Data[0].Embedding, nil
}

// ──────────────────────────────────────────────────────────────
// Qdrant gRPC connection
// ──────────────────────────────────────────────────────────────

// qdrantOutgoingContext derives a context from parent that carries the
// Qdrant API key as outgoing gRPC metadata. parent is returned unchanged
// when no key is configured.
func qdrantOutgoingContext(parent context.Context, cfg *system.SysAiConfig) context.Context {
	if cfg.QdrantApiKey == "" {
		return parent
	}
	return metadata.NewOutgoingContext(parent, metadata.Pairs("api-key", cfg.QdrantApiKey))
}

// newQdrantConn opens a gRPC client connection to the configured Qdrant
// address and returns it together with a background context that carries the
// API key (if any). The caller owns the connection and must close it.
//
// NOTE(review): an "https://" prefix is stripped but the connection still
// uses insecure (plaintext) transport credentials — confirm this is intended.
func newQdrantConn(cfg *system.SysAiConfig) (*grpc.ClientConn, context.Context, error) {
	addr := strings.TrimPrefix(cfg.QdrantUrl, "http://")
	addr = strings.TrimPrefix(addr, "https://")

	conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err)
	}
	return conn, qdrantOutgoingContext(context.Background(), cfg), nil
}

// EnsureCollection makes sure the configured collection exists, creating it
// with cosine distance when the lookup fails.
//
// Any error from the lookup is treated as "collection missing"; a genuine
// connectivity problem therefore surfaces as a create failure.
func EnsureCollection(cfg *system.SysAiConfig) error {
	conn, ctx, err := newQdrantConn(cfg)
	if err != nil {
		return err
	}
	defer conn.Close()

	dim := uint64(cfg.VectorDimension)
	if dim == 0 {
		dim = ragDefaultVectorDimension
	}

	collClient := qdrant.NewCollectionsClient(conn)
	_, getErr := collClient.Get(ctx, &qdrant.GetCollectionInfoRequest{CollectionName: cfg.QdrantCollection})
	if getErr == nil {
		return nil // already exists
	}

	_, err = collClient.Create(ctx, &qdrant.CreateCollection{
		CollectionName: cfg.QdrantCollection,
		VectorsConfig: &qdrant.VectorsConfig{
			Config: &qdrant.VectorsConfig_Params{
				Params: &qdrant.VectorParams{Size: dim, Distance: qdrant.Distance_Cosine},
			},
		},
	})
	if err != nil {
		return fmt.Errorf("qdrant create collection failed: %w", err)
	}
	global.Logger.Info("Qdrant collection created", zap.String("collection", cfg.QdrantCollection))
	return nil
}

// wikiToQdrantID maps a wiki ID to a deterministic (MD5 name-based) UUID so
// that syncing the same wiki again overwrites the same Qdrant point.
func wikiToQdrantID(wikiId string) string {
	return uuid.NewMD5(uuid.NameSpaceOID, []byte(wikiId)).String()
}

// buildWikiText assembles the text corpus that gets embedded for a wiki entry.
func buildWikiText(w plantModel.Wiki) string {
	return fmt.Sprintf(
		"植物名字:%s. 拉丁名:%s. 科属:%s. 生命周期:%s. 生长习性:%s. 病虫害:%s. 光照类型:%s. 最佳温度:%s.",
		w.Name, w.LatinName, w.Genus, w.LifeCycle, w.GrowthHabit, w.PestsDiseases, w.LightType, w.OptimalTempPeriod,
	)
}

// upsertWikiPoint writes one wiki entry (vector plus payload) into the
// configured collection.
func upsertWikiPoint(ctx context.Context, pts qdrant.PointsClient, cfg *system.SysAiConfig, w *plantModel.Wiki, text string, vec []float32) error {
	_, err := pts.Upsert(ctx, &qdrant.UpsertPoints{
		CollectionName: cfg.QdrantCollection,
		Points: []*qdrant.PointStruct{{
			Id:      qdrant.NewID(wikiToQdrantID(w.Id)),
			Vectors: qdrant.NewVectors(vec...),
			Payload: map[string]*qdrant.Value{
				"wiki_id":   qdrant.NewValueString(w.Id),
				"name":      qdrant.NewValueString(w.Name),
				"full_text": qdrant.NewValueString(text),
			},
		}},
	})
	return err
}

// pointsSelector wraps a list of point IDs in the selector shape expected by
// Qdrant delete requests.
func pointsSelector(ids []*qdrant.PointId) *qdrant.PointsSelector {
	return &qdrant.PointsSelector{
		PointsSelectorOneOf: &qdrant.PointsSelector_Points{
			Points: &qdrant.PointsIdsList{Ids: ids},
		},
	}
}

// ──────────────────────────────────────────────────────────────
// SyncSingleWikiAsync syncs one wiki entry to Qdrant in the background
// (called on create/update). On success is_vector_synced is set to 1.
// Failures are only logged.
// ──────────────────────────────────────────────────────────────
func (s *AiRagService) SyncSingleWikiAsync(wikiId string) {
	go func() {
		if err := s.syncSingleWiki(wikiId); err != nil {
			global.Logger.Error("Async sync wiki to Qdrant failed", zap.String("wiki_id", wikiId), zap.Error(err))
		}
	}()
}

// SyncSingleWiki syncs one wiki entry to Qdrant synchronously (for direct
// API calls) and returns the first error encountered.
func (s *AiRagService) SyncSingleWiki(wikiId string) error {
	return s.syncSingleWiki(wikiId)
}

// syncSingleWiki loads the wiki, embeds its text, upserts the point into
// Qdrant and finally marks the row as synced.
func (s *AiRagService) syncSingleWiki(wikiId string) error {
	cfg, err := sysAiConfigService.GetActiveAiConfig()
	if err != nil {
		return fmt.Errorf("no active ai config: %w", err)
	}

	// Best effort: if this fails the upsert below reports the real problem.
	if err = EnsureCollection(cfg); err != nil {
		global.Logger.Warn("EnsureCollection warn", zap.Error(err))
	}

	var w plantModel.Wiki
	if err = global.DB.Where("id = ?", wikiId).First(&w).Error; err != nil {
		return fmt.Errorf("wiki not found: %w", err)
	}

	text := buildWikiText(w)
	vec, err := embedText(context.Background(), getEmbeddingClient(cfg), cfg, text)
	if err != nil {
		return fmt.Errorf("embedding failed: %w", err)
	}

	conn, qdCtx, err := newQdrantConn(cfg)
	if err != nil {
		return err
	}
	defer conn.Close()

	if err = upsertWikiPoint(qdCtx, qdrant.NewPointsClient(conn), cfg, &w, text, vec); err != nil {
		return fmt.Errorf("qdrant upsert failed: %w", err)
	}

	// Update the sync flag. The vector is already stored, so a failure here
	// is logged rather than returned.
	if dbErr := global.DB.Model(&plantModel.Wiki{}).Where("id = ?", wikiId).Update("is_vector_synced", 1).Error; dbErr != nil {
		global.Logger.Warn("Update is_vector_synced failed", zap.String("wiki_id", wikiId), zap.Error(dbErr))
	}
	global.Logger.Info("Wiki synced to Qdrant", zap.String("wiki_id", wikiId))
	return nil
}

// ──────────────────────────────────────────────────────────────
// DeleteFromQdrant removes the vector point of a single wiki entry from
// Qdrant. On success is_vector_synced is set to 0.
// ──────────────────────────────────────────────────────────────
func (s *AiRagService) DeleteFromQdrant(wikiId string) error {
	cfg, err := sysAiConfigService.GetActiveAiConfig()
	if err != nil {
		return fmt.Errorf("no active ai config: %w", err)
	}

	conn, qdCtx, err := newQdrantConn(cfg)
	if err != nil {
		return err
	}
	defer conn.Close()

	ids := []*qdrant.PointId{qdrant.NewID(wikiToQdrantID(wikiId))}
	_, err = qdrant.NewPointsClient(conn).Delete(qdCtx, &qdrant.DeletePoints{
		CollectionName: cfg.QdrantCollection,
		Points:         pointsSelector(ids),
	})
	if err != nil {
		return fmt.Errorf("qdrant delete failed: %w", err)
	}

	// The point is already gone, so a flag-update failure is logged rather
	// than returned.
	if dbErr := global.DB.Model(&plantModel.Wiki{}).Where("id = ?", wikiId).Update("is_vector_synced", 0).Error; dbErr != nil {
		global.Logger.Warn("Update is_vector_synced failed", zap.String("wiki_id", wikiId), zap.Error(dbErr))
	}
	global.Logger.Info("Wiki deleted from Qdrant", zap.String("wiki_id", wikiId))
	return nil
}

// DeleteFromQdrantBatch removes the points of several wiki entries from
// Qdrant in the background (used when wikis are deleted in bulk). Failures
// are only logged. An empty input is a no-op.
func (s *AiRagService) DeleteFromQdrantBatch(wikiIds []string) {
	if len(wikiIds) == 0 {
		return
	}

	// Derive the point IDs before starting the goroutine so it never reads
	// the caller's slice, which the caller may reuse or mutate.
	ids := make([]*qdrant.PointId, 0, len(wikiIds))
	for _, wid := range wikiIds {
		ids = append(ids, qdrant.NewID(wikiToQdrantID(wid)))
	}

	go func() {
		cfg, err := sysAiConfigService.GetActiveAiConfig()
		if err != nil {
			global.Logger.Warn("No active ai config for batch qdrant delete", zap.Error(err))
			return
		}

		conn, qdCtx, err := newQdrantConn(cfg)
		if err != nil {
			global.Logger.Warn("Qdrant connect failed for batch delete", zap.Error(err))
			return
		}
		defer conn.Close()

		_, err = qdrant.NewPointsClient(conn).Delete(qdCtx, &qdrant.DeletePoints{
			CollectionName: cfg.QdrantCollection,
			Points:         pointsSelector(ids),
		})
		if err != nil {
			global.Logger.Error("Qdrant batch delete failed", zap.Error(err))
			return
		}
		global.Logger.Info("Qdrant batch delete done", zap.Int("count", len(ids)))
	}()
}

// ──────────────────────────────────────────────────────────────
// SyncWikiToQdrant performs a full sync of every wiki entry (background
// job / manual trigger). Entries that fail to embed or upsert are logged
// and skipped; the successfully synced ones get is_vector_synced = 1.
// ──────────────────────────────────────────────────────────────
func (s *AiRagService) SyncWikiToQdrant() error {
	cfg, err := sysAiConfigService.GetActiveAiConfig()
	if err != nil {
		return err
	}

	if err = EnsureCollection(cfg); err != nil {
		global.Logger.Warn("EnsureCollection failed, continuing", zap.Error(err))
	}

	var wikis []plantModel.Wiki
	if err = global.DB.Find(&wikis).Error; err != nil {
		return err
	}

	embClient := getEmbeddingClient(cfg)
	conn, qdCtx, err := newQdrantConn(cfg)
	if err != nil {
		return err
	}
	defer conn.Close()
	ptsClient := qdrant.NewPointsClient(conn)

	successIds := make([]string, 0, len(wikis))
	for i := range wikis {
		w := &wikis[i]
		text := buildWikiText(*w)

		vec, embErr := embedText(context.Background(), embClient, cfg, text)
		if embErr != nil {
			global.Logger.Error("Embedding failed", zap.String("wiki_id", w.Id), zap.Error(embErr))
			continue
		}

		if upsertErr := upsertWikiPoint(qdCtx, ptsClient, cfg, w, text, vec); upsertErr != nil {
			global.Logger.Error("Qdrant upsert failed", zap.String("wiki_id", w.Id), zap.Error(upsertErr))
			continue
		}
		successIds = append(successIds, w.Id)
	}

	// Update the sync flag for all successful entries in one statement.
	if len(successIds) > 0 {
		if dbErr := global.DB.Model(&plantModel.Wiki{}).Where("id IN ?", successIds).Update("is_vector_synced", 1).Error; dbErr != nil {
			global.Logger.Warn("Batch update is_vector_synced failed", zap.Error(dbErr))
		}
	}
	global.Logger.Info("SyncWikiToQdrant done", zap.Int("total", len(wikis)), zap.Int("success", len(successIds)))
	return nil
}

// searchWikiContext looks up the wiki entries nearest to vec and returns
// their "full_text" payloads, one per line. Retrieval is best effort: on a
// connection or search failure it logs a warning and returns "".
//
// The search runs on the caller's ctx so it is cancelled together with the
// request, and the connection is closed before returning.
func searchWikiContext(ctx context.Context, cfg *system.SysAiConfig, vec []float32) string {
	conn, _, err := newQdrantConn(cfg)
	if err != nil {
		global.Logger.Warn("Qdrant connect failed, skipping RAG", zap.Error(err))
		return ""
	}
	defer conn.Close()

	res, err := qdrant.NewPointsClient(conn).Search(qdrantOutgoingContext(ctx, cfg), &qdrant.SearchPoints{
		CollectionName: cfg.QdrantCollection,
		Vector:         vec,
		Limit:          uint64(ragSearchLimit),
		WithPayload: &qdrant.WithPayloadSelector{
			SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true},
		},
	})
	if err != nil {
		global.Logger.Warn("Qdrant search failed, using empty context", zap.Error(err))
		return ""
	}

	var b strings.Builder
	for _, pt := range res.GetResult() {
		if txt, ok := pt.GetPayload()["full_text"]; ok {
			b.WriteString(txt.GetStringValue())
			b.WriteByte('\n')
		}
	}
	return b.String()
}

// ──────────────────────────────────────────────────────────────
// PlantChatStreamRAG answers userQuery with vector retrieval plus a
// streaming chat completion.
//
// The query is embedded, the nearest wiki entries are fetched from Qdrant
// (skipped with a warning when Qdrant is unavailable) and prepended to the
// system prompt. Every non-empty content delta of the model's reply is
// passed to onData; an error returned by onData aborts the stream and is
// returned. A nil onData is rejected before any remote call is made.
// ──────────────────────────────────────────────────────────────
func (s *AiRagService) PlantChatStreamRAG(ctx context.Context, userQuery string, onData func(chunk string) error) error {
	if onData == nil {
		return errors.New("onData callback is nil")
	}

	cfg, err := sysAiConfigService.GetActiveAiConfig()
	if err != nil {
		return err
	}

	queryVec, err := embedText(ctx, getEmbeddingClient(cfg), cfg, userQuery)
	if err != nil {
		return fmt.Errorf("向量化查询失败: %w", err)
	}

	systemPrompt := "你是一个专业的植物百科助手,请基于以下知识库信息回答用户问题。如果知识库无相关信息,结合你的通用知识作答。\n"
	if contextText := searchWikiContext(ctx, cfg, queryVec); contextText != "" {
		systemPrompt += "--- 知识库 ---\n" + contextText + "\n--------------\n"
	}

	stream, err := getChatClient(cfg).CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
		Model: cfg.ChatModelName,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
			{Role: openai.ChatMessageRoleUser, Content: userQuery},
		},
		Stream: true,
	})
	if err != nil {
		return fmt.Errorf("大模型调用失败: %w", err)
	}
	defer stream.Close()

	for {
		resp, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			return nil
		}
		if recvErr != nil {
			return recvErr
		}
		if len(resp.Choices) == 0 {
			continue
		}
		content := resp.Choices[0].Delta.Content
		if content == "" {
			continue
		}
		if writeErr := onData(content); writeErr != nil {
			return writeErr
		}
	}
}