feat: 植物识别百科ai助手迁移
This commit is contained in:
@@ -1,131 +1,101 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
qdrant "github.com/qdrant/go-client/qdrant"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"sundynix-micro-go/app/plant/api/internal/svc"
|
||||
plantModel "sundynix-micro-go/app/plant/model"
|
||||
plantPb "sundynix-micro-go/app/plant/rpc/plant"
|
||||
)
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
func getActiveAiConfig(svcCtx *svc.ServiceContext) (*plantModel.SysAiConfig, error) {
|
||||
var cfg plantModel.SysAiConfig
|
||||
err := svcCtx.DB.Where("is_active = 1").First(&cfg).Error
|
||||
if err != nil {
|
||||
return nil, errors.New("AI/RAG 问答服务暂未激活或数据库配置缺失")
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
func chatModel(svcCtx *svc.ServiceContext) string {
|
||||
if svcCtx.Config.Ai.ChatModelName != "" {
|
||||
return svcCtx.Config.Ai.ChatModelName
|
||||
func chatModel(dbCfg *plantModel.SysAiConfig) string {
|
||||
if dbCfg.ChatModelName != "" {
|
||||
return dbCfg.ChatModelName
|
||||
}
|
||||
return "gpt-4o-mini"
|
||||
}
|
||||
|
||||
func requestBody(svcCtx *svc.ServiceContext, question string, stream bool) ([]byte, error) {
|
||||
systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。"
|
||||
if ctxText := retrieveRAGContext(context.Background(), svcCtx, question); ctxText != "" {
|
||||
systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------"
|
||||
func newQdrantConn(cfg *plantModel.SysAiConfig) (*grpc.ClientConn, context.Context, error) {
|
||||
addr := strings.TrimPrefix(cfg.QdrantUrl, "http://")
|
||||
addr = strings.TrimPrefix(addr, "https://")
|
||||
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err)
|
||||
}
|
||||
return json.Marshal(chatRequest{
|
||||
Model: chatModel(svcCtx),
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: question},
|
||||
},
|
||||
Stream: stream,
|
||||
})
|
||||
ctx := context.Background()
|
||||
if cfg.QdrantApiKey != "" {
|
||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("api-key", cfg.QdrantApiKey))
|
||||
}
|
||||
return conn, ctx, nil
|
||||
}
|
||||
|
||||
func retrieveRAGContext(ctx context.Context, svcCtx *svc.ServiceContext, question string) string {
|
||||
c := svcCtx.Config.Ai
|
||||
if c.EmbeddingApiUrl == "" || c.EmbeddingApiKey == "" || c.QdrantUrl == "" || c.QdrantCollection == "" {
|
||||
return ""
|
||||
}
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"model": c.EmbeddingModelName,
|
||||
"input": question,
|
||||
})
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.EmbeddingApiUrl, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.EmbeddingApiKey)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var emb struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 || json.NewDecoder(resp.Body).Decode(&emb) != nil || len(emb.Data) == 0 {
|
||||
func retrieveRAGContext(ctx context.Context, svcCtx *svc.ServiceContext, dbCfg *plantModel.SysAiConfig, question string) string {
|
||||
if dbCfg.EmbeddingApiUrl == "" || dbCfg.EmbeddingApiKey == "" || dbCfg.QdrantUrl == "" || dbCfg.QdrantCollection == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
searchBody, _ := json.Marshal(map[string]interface{}{
|
||||
"vector": emb.Data[0].Embedding,
|
||||
"limit": 3,
|
||||
"with_payload": true,
|
||||
config := openai.DefaultConfig(dbCfg.EmbeddingApiKey)
|
||||
if dbCfg.EmbeddingApiUrl != "" {
|
||||
config.BaseURL = dbCfg.EmbeddingApiUrl
|
||||
}
|
||||
client := openai.NewClientWithConfig(config)
|
||||
embResp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
|
||||
Input: []string{question},
|
||||
Model: openai.EmbeddingModel(dbCfg.EmbeddingModelName),
|
||||
})
|
||||
searchReq, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(c.QdrantUrl, "/")+"/collections/"+c.QdrantCollection+"/points/search", bytes.NewReader(searchBody))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
searchReq.Header.Set("Content-Type", "application/json")
|
||||
if c.QdrantApiKey != "" {
|
||||
searchReq.Header.Set("api-key", c.QdrantApiKey)
|
||||
}
|
||||
searchResp, err := http.DefaultClient.Do(searchReq)
|
||||
if err != nil {
|
||||
|
||||
conn, qdCtx, connErr := newQdrantConn(dbCfg)
|
||||
if connErr != nil {
|
||||
return ""
|
||||
}
|
||||
defer searchResp.Body.Close()
|
||||
var parsed struct {
|
||||
Result []struct {
|
||||
Payload map[string]interface{} `json:"payload"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if searchResp.StatusCode < 200 || searchResp.StatusCode >= 300 || json.NewDecoder(searchResp.Body).Decode(&parsed) != nil {
|
||||
defer conn.Close()
|
||||
|
||||
ptsClient := qdrant.NewPointsClient(conn)
|
||||
searchRes, searchErr := ptsClient.Search(qdCtx, &qdrant.SearchPoints{
|
||||
CollectionName: dbCfg.QdrantCollection,
|
||||
Vector: embResp.Data[0].Embedding,
|
||||
Limit: 3,
|
||||
WithPayload: &qdrant.WithPayloadSelector{
|
||||
SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true},
|
||||
},
|
||||
})
|
||||
if searchErr != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for _, item := range parsed.Result {
|
||||
if text, ok := item.Payload["full_text"].(string); ok && text != "" {
|
||||
b.WriteString(text)
|
||||
for _, pt := range searchRes.GetResult() {
|
||||
if txt, ok := pt.GetPayload()["full_text"]; ok {
|
||||
b.WriteString(txt.GetStringValue())
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func newChatRequest(ctx context.Context, svcCtx *svc.ServiceContext, body []byte) (*http.Request, error) {
|
||||
if svcCtx.Config.Ai.ChatApiUrl == "" || svcCtx.Config.Ai.ChatApiKey == "" {
|
||||
return nil, errors.New("AI/RAG 未配置 ChatApiUrl 或 ChatApiKey")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, svcCtx.Config.Ai.ChatApiUrl, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+svcCtx.Config.Ai.ChatApiKey)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func SaveHistory(ctx context.Context, svcCtx *svc.ServiceContext, userID, question, answer string) {
|
||||
if userID == "" || question == "" || answer == "" {
|
||||
return
|
||||
@@ -136,102 +106,100 @@ func SaveHistory(ctx context.Context, svcCtx *svc.ServiceContext, userID, questi
|
||||
}
|
||||
|
||||
func ChatCompletion(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string) (string, error) {
|
||||
if err := ensureQuota(ctx, svcCtx, userID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
body, err := requestBody(svcCtx, question, false)
|
||||
dbCfg, err := getActiveAiConfig(svcCtx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req, err := newChatRequest(ctx, svcCtx, body)
|
||||
if err != nil {
|
||||
if err := ensureQuota(ctx, svcCtx, userID, dbCfg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return "", fmt.Errorf("AI 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。"
|
||||
if ctxText := retrieveRAGContext(ctx, svcCtx, dbCfg, question); ctxText != "" {
|
||||
systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------"
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
|
||||
config := openai.DefaultConfig(dbCfg.ChatApiKey)
|
||||
if dbCfg.ChatApiUrl != "" {
|
||||
config.BaseURL = dbCfg.ChatApiUrl
|
||||
}
|
||||
client := openai.NewClientWithConfig(config)
|
||||
resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
|
||||
Model: chatModel(dbCfg),
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
|
||||
{Role: openai.ChatMessageRoleUser, Content: question},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(parsed.Choices) == 0 || parsed.Choices[0].Message.Content == "" {
|
||||
if len(resp.Choices) == 0 || resp.Choices[0].Message.Content == "" {
|
||||
return "", errors.New("AI 响应为空")
|
||||
}
|
||||
answer := parsed.Choices[0].Message.Content
|
||||
answer := resp.Choices[0].Message.Content
|
||||
SaveHistory(ctx, svcCtx, userID, question, answer)
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
func StreamChat(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string, w io.Writer) error {
|
||||
if err := ensureQuota(ctx, svcCtx, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := requestBody(svcCtx, question, true)
|
||||
dbCfg, err := getActiveAiConfig(svcCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req, err := newChatRequest(ctx, svcCtx, body)
|
||||
if err := ensureQuota(ctx, svcCtx, userID, dbCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。"
|
||||
if ctxText := retrieveRAGContext(ctx, svcCtx, dbCfg, question); ctxText != "" {
|
||||
systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------"
|
||||
}
|
||||
|
||||
config := openai.DefaultConfig(dbCfg.ChatApiKey)
|
||||
if dbCfg.ChatApiUrl != "" {
|
||||
config.BaseURL = dbCfg.ChatApiUrl
|
||||
}
|
||||
client := openai.NewClientWithConfig(config)
|
||||
stream, err := client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
||||
Model: chatModel(dbCfg),
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
|
||||
{Role: openai.ChatMessageRoleUser, Content: question},
|
||||
},
|
||||
Stream: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return fmt.Errorf("AI 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var answer strings.Builder
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
_, _ = fmt.Fprintln(w, line)
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
for {
|
||||
resp, recvErr := stream.Recv()
|
||||
if errors.Is(recvErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if recvErr != nil {
|
||||
return recvErr
|
||||
}
|
||||
if len(resp.Choices) > 0 {
|
||||
content := resp.Choices[0].Delta.Content
|
||||
if content != "" {
|
||||
_, _ = fmt.Fprintf(w, "data: %s\n\n", content)
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
answer.WriteString(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var chunk struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if json.Unmarshal([]byte(data), &chunk) == nil && len(chunk.Choices) > 0 {
|
||||
answer.WriteString(chunk.Choices[0].Delta.Content)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
SaveHistory(ctx, svcCtx, userID, question, answer.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureQuota(ctx context.Context, svcCtx *svc.ServiceContext, userID string) error {
|
||||
func ensureQuota(ctx context.Context, svcCtx *svc.ServiceContext, userID string, dbCfg *plantModel.SysAiConfig) error {
|
||||
if userID == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -239,8 +207,9 @@ func ensureQuota(ctx context.Context, svcCtx *svc.ServiceContext, userID string)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if quota.Limit > 0 && quota.Remaining <= 0 {
|
||||
return fmt.Errorf("今日问答次数已达上限(%d次),明天再来吧", quota.Limit)
|
||||
limit := int64(dbCfg.DailyQueryLimit)
|
||||
if limit > 0 && quota.Remaining <= 0 {
|
||||
return fmt.Errorf("今日问答次数已达上限(%d次),明天再来吧", limit)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user