package ai import ( "context" "errors" "fmt" "io" "net/http" "strings" qdrant "github.com/qdrant/go-client/qdrant" openai "github.com/sashabaranov/go-openai" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" "sundynix-micro-go/app/plant/api/internal/svc" plantModel "sundynix-micro-go/app/plant/model" plantPb "sundynix-micro-go/app/plant/rpc/plant" ) func getActiveAiConfig(svcCtx *svc.ServiceContext) (*plantModel.SysAiConfig, error) { var cfg plantModel.SysAiConfig err := svcCtx.DB.Where("is_active = 1").First(&cfg).Error if err != nil { return nil, errors.New("AI/RAG 问答服务暂未激活或数据库配置缺失") } return &cfg, nil } func chatModel(dbCfg *plantModel.SysAiConfig) string { if dbCfg.ChatModelName != "" { return dbCfg.ChatModelName } return "gpt-4o-mini" } func newQdrantConn(cfg *plantModel.SysAiConfig) (*grpc.ClientConn, context.Context, error) { addr := strings.TrimPrefix(cfg.QdrantUrl, "http://") addr = strings.TrimPrefix(addr, "https://") conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err) } ctx := context.Background() if cfg.QdrantApiKey != "" { ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("api-key", cfg.QdrantApiKey)) } return conn, ctx, nil } func retrieveRAGContext(ctx context.Context, svcCtx *svc.ServiceContext, dbCfg *plantModel.SysAiConfig, question string) string { if dbCfg.EmbeddingApiUrl == "" || dbCfg.EmbeddingApiKey == "" || dbCfg.QdrantUrl == "" || dbCfg.QdrantCollection == "" { return "" } config := openai.DefaultConfig(dbCfg.EmbeddingApiKey) if dbCfg.EmbeddingApiUrl != "" { config.BaseURL = dbCfg.EmbeddingApiUrl } client := openai.NewClientWithConfig(config) embResp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ Input: []string{question}, Model: openai.EmbeddingModel(dbCfg.EmbeddingModelName), }) if err != nil { return "" } conn, qdCtx, connErr := newQdrantConn(dbCfg) if connErr != nil { return "" } defer conn.Close() ptsClient := qdrant.NewPointsClient(conn) searchRes, searchErr := ptsClient.Search(qdCtx, &qdrant.SearchPoints{ CollectionName: dbCfg.QdrantCollection, Vector: embResp.Data[0].Embedding, Limit: 3, WithPayload: &qdrant.WithPayloadSelector{ SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true}, }, }) if searchErr != nil { return "" } var b strings.Builder for _, pt := range searchRes.GetResult() { if txt, ok := pt.GetPayload()["full_text"]; ok { b.WriteString(txt.GetStringValue()) b.WriteString("\n") } } return b.String() } func SaveHistory(ctx context.Context, svcCtx *svc.ServiceContext, userID, question, answer string) { if userID == "" || question == "" || answer == "" { return } _, _ = svcCtx.PlantRpc.SaveAiChatHistory(ctx, &plantPb.SaveAiChatHistoryReq{ UserId: userID, Question: question, Answer: answer, }) } func ChatCompletion(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string) (string, error) { dbCfg, err := getActiveAiConfig(svcCtx) if err != nil { return "", err } if err := ensureQuota(ctx, svcCtx, userID, dbCfg); err != nil { return "", err } systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。" if ctxText := retrieveRAGContext(ctx, svcCtx, dbCfg, question); ctxText != "" { systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------" } config := openai.DefaultConfig(dbCfg.ChatApiKey) if dbCfg.ChatApiUrl != "" { config.BaseURL = dbCfg.ChatApiUrl } client := openai.NewClientWithConfig(config) resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ Model: chatModel(dbCfg), Messages: []openai.ChatCompletionMessage{ {Role: openai.ChatMessageRoleSystem, Content: systemPrompt}, {Role: openai.ChatMessageRoleUser, Content: question}, }, }) if err != nil { return "", err } if len(resp.Choices) == 0 || resp.Choices[0].Message.Content == "" { return "", errors.New("AI 响应为空") } answer := resp.Choices[0].Message.Content SaveHistory(ctx, svcCtx, userID, question, answer) return answer, nil } func StreamChat(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string, w io.Writer) error { dbCfg, err := getActiveAiConfig(svcCtx) if err != nil { return err } if err := ensureQuota(ctx, svcCtx, userID, dbCfg); err != nil { return err } systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。" if ctxText := retrieveRAGContext(ctx, svcCtx, dbCfg, question); ctxText != "" { systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------" } config := openai.DefaultConfig(dbCfg.ChatApiKey) if dbCfg.ChatApiUrl != "" { config.BaseURL = dbCfg.ChatApiUrl } client := openai.NewClientWithConfig(config) stream, err := client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{ Model: chatModel(dbCfg), Messages: []openai.ChatCompletionMessage{ {Role: openai.ChatMessageRoleSystem, Content: systemPrompt}, {Role: openai.ChatMessageRoleUser, Content: question}, }, Stream: true, }) if err != nil { return err } defer stream.Close() var answer strings.Builder for { resp, recvErr := stream.Recv() if errors.Is(recvErr, io.EOF) { break } if recvErr != nil { return recvErr } if len(resp.Choices) > 0 { content := resp.Choices[0].Delta.Content if content != "" { _, _ = fmt.Fprintf(w, "data: %s\n\n", content) if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } answer.WriteString(content) } } } SaveHistory(ctx, svcCtx, userID, question, answer.String()) return nil } func ensureQuota(ctx context.Context, svcCtx *svc.ServiceContext, userID string, dbCfg *plantModel.SysAiConfig) error { if userID == "" { return nil } quota, err := svcCtx.PlantRpc.GetAiChatQuota(ctx, &plantPb.GetProfileReq{UserId: userID}) if err != nil { return err } limit := int64(dbCfg.DailyQueryLimit) if limit > 0 && quota.Remaining <= 0 { return fmt.Errorf("今日问答次数已达上限(%d次),明天再来吧", limit) } return nil }