package ai import ( "bufio" "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "strings" "sundynix-micro-go/app/plant/api/internal/svc" plantPb "sundynix-micro-go/app/plant/rpc/plant" ) type chatMessage struct { Role string `json:"role"` Content string `json:"content"` } type chatRequest struct { Model string `json:"model,omitempty"` Messages []chatMessage `json:"messages"` Stream bool `json:"stream"` } func chatModel(svcCtx *svc.ServiceContext) string { if svcCtx.Config.Ai.ChatModelName != "" { return svcCtx.Config.Ai.ChatModelName } return "gpt-4o-mini" } func requestBody(svcCtx *svc.ServiceContext, question string, stream bool) ([]byte, error) { systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。" if ctxText := retrieveRAGContext(context.Background(), svcCtx, question); ctxText != "" { systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------" } return json.Marshal(chatRequest{ Model: chatModel(svcCtx), Messages: []chatMessage{ {Role: "system", Content: systemPrompt}, {Role: "user", Content: question}, }, Stream: stream, }) } func retrieveRAGContext(ctx context.Context, svcCtx *svc.ServiceContext, question string) string { c := svcCtx.Config.Ai if c.EmbeddingApiUrl == "" || c.EmbeddingApiKey == "" || c.QdrantUrl == "" || c.QdrantCollection == "" { return "" } body, _ := json.Marshal(map[string]interface{}{ "model": c.EmbeddingModelName, "input": question, }) req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.EmbeddingApiUrl, bytes.NewReader(body)) if err != nil { return "" } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.EmbeddingApiKey) resp, err := http.DefaultClient.Do(req) if err != nil { return "" } defer resp.Body.Close() var emb struct { Data []struct { Embedding []float32 `json:"embedding"` } `json:"data"` } if resp.StatusCode < 200 || resp.StatusCode >= 300 || json.NewDecoder(resp.Body).Decode(&emb) != nil || len(emb.Data) == 0 { return "" } searchBody, _ := json.Marshal(map[string]interface{}{ "vector": emb.Data[0].Embedding, "limit": 3, "with_payload": true, }) searchReq, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(c.QdrantUrl, "/")+"/collections/"+c.QdrantCollection+"/points/search", bytes.NewReader(searchBody)) if err != nil { return "" } searchReq.Header.Set("Content-Type", "application/json") if c.QdrantApiKey != "" { searchReq.Header.Set("api-key", c.QdrantApiKey) } searchResp, err := http.DefaultClient.Do(searchReq) if err != nil { return "" } defer searchResp.Body.Close() var parsed struct { Result []struct { Payload map[string]interface{} `json:"payload"` } `json:"result"` } if searchResp.StatusCode < 200 || searchResp.StatusCode >= 300 || json.NewDecoder(searchResp.Body).Decode(&parsed) != nil { return "" } var b strings.Builder for _, item := range parsed.Result { if text, ok := item.Payload["full_text"].(string); ok && text != "" { b.WriteString(text) b.WriteString("\n") } } return b.String() } func newChatRequest(ctx context.Context, svcCtx *svc.ServiceContext, body []byte) (*http.Request, error) { if svcCtx.Config.Ai.ChatApiUrl == "" || svcCtx.Config.Ai.ChatApiKey == "" { return nil, errors.New("AI/RAG 未配置 ChatApiUrl 或 ChatApiKey") } req, err := http.NewRequestWithContext(ctx, http.MethodPost, svcCtx.Config.Ai.ChatApiUrl, bytes.NewReader(body)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+svcCtx.Config.Ai.ChatApiKey) return req, nil } func SaveHistory(ctx context.Context, svcCtx *svc.ServiceContext, userID, question, answer string) { if userID == "" || question == "" || answer == "" { return } _, _ = svcCtx.PlantRpc.SaveAiChatHistory(ctx, &plantPb.SaveAiChatHistoryReq{ UserId: userID, Question: question, Answer: answer, }) } func ChatCompletion(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string) (string, error) { if err := ensureQuota(ctx, svcCtx, userID); err != nil { return "", err } body, err := requestBody(svcCtx, question, false) if err != nil { return "", err } req, err := newChatRequest(ctx, svcCtx, body) if err != nil { return "", err } resp, err := http.DefaultClient.Do(req) if err != nil { return "", err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return "", fmt.Errorf("AI 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw))) } var parsed struct { Choices []struct { Message struct { Content string `json:"content"` } `json:"message"` } `json:"choices"` } if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { return "", err } if len(parsed.Choices) == 0 || parsed.Choices[0].Message.Content == "" { return "", errors.New("AI 响应为空") } answer := parsed.Choices[0].Message.Content SaveHistory(ctx, svcCtx, userID, question, answer) return answer, nil } func StreamChat(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string, w io.Writer) error { if err := ensureQuota(ctx, svcCtx, userID); err != nil { return err } body, err := requestBody(svcCtx, question, true) if err != nil { return err } req, err := newChatRequest(ctx, svcCtx, body) if err != nil { return err } resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return fmt.Errorf("AI 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw))) } var answer strings.Builder scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { line := scanner.Text() _, _ = fmt.Fprintln(w, line) if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) if data == "[DONE]" { continue } var chunk struct { Choices []struct { Delta struct { Content string `json:"content"` } `json:"delta"` } `json:"choices"` } if json.Unmarshal([]byte(data), &chunk) == nil && len(chunk.Choices) > 0 { answer.WriteString(chunk.Choices[0].Delta.Content) } } if err := scanner.Err(); err != nil { return err } SaveHistory(ctx, svcCtx, userID, question, answer.String()) return nil } func ensureQuota(ctx context.Context, svcCtx *svc.ServiceContext, userID string) error { if userID == "" { return nil } quota, err := svcCtx.PlantRpc.GetAiChatQuota(ctx, &plantPb.GetProfileReq{UserId: userID}) if err != nil { return err } if quota.Limit > 0 && quota.Remaining <= 0 { return fmt.Errorf("今日问答次数已达上限(%d次),明天再来吧", quota.Limit) } return nil }