216 lines
6.4 KiB
Go
216 lines
6.4 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
qdrant "github.com/qdrant/go-client/qdrant"
|
|
openai "github.com/sashabaranov/go-openai"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/metadata"
|
|
|
|
"sundynix-micro-go/app/plant/api/internal/svc"
|
|
plantModel "sundynix-micro-go/app/plant/model"
|
|
plantPb "sundynix-micro-go/app/plant/rpc/plant"
|
|
)
|
|
|
|
func getActiveAiConfig(svcCtx *svc.ServiceContext) (*plantModel.SysAiConfig, error) {
|
|
var cfg plantModel.SysAiConfig
|
|
err := svcCtx.DB.Where("is_active = 1").First(&cfg).Error
|
|
if err != nil {
|
|
return nil, errors.New("AI/RAG 问答服务暂未激活或数据库配置缺失")
|
|
}
|
|
return &cfg, nil
|
|
}
|
|
|
|
func chatModel(dbCfg *plantModel.SysAiConfig) string {
|
|
if dbCfg.ChatModelName != "" {
|
|
return dbCfg.ChatModelName
|
|
}
|
|
return "gpt-4o-mini"
|
|
}
|
|
|
|
func newQdrantConn(cfg *plantModel.SysAiConfig) (*grpc.ClientConn, context.Context, error) {
|
|
addr := strings.TrimPrefix(cfg.QdrantUrl, "http://")
|
|
addr = strings.TrimPrefix(addr, "https://")
|
|
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err)
|
|
}
|
|
ctx := context.Background()
|
|
if cfg.QdrantApiKey != "" {
|
|
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("api-key", cfg.QdrantApiKey))
|
|
}
|
|
return conn, ctx, nil
|
|
}
|
|
|
|
func retrieveRAGContext(ctx context.Context, svcCtx *svc.ServiceContext, dbCfg *plantModel.SysAiConfig, question string) string {
|
|
if dbCfg.EmbeddingApiUrl == "" || dbCfg.EmbeddingApiKey == "" || dbCfg.QdrantUrl == "" || dbCfg.QdrantCollection == "" {
|
|
return ""
|
|
}
|
|
|
|
config := openai.DefaultConfig(dbCfg.EmbeddingApiKey)
|
|
if dbCfg.EmbeddingApiUrl != "" {
|
|
config.BaseURL = dbCfg.EmbeddingApiUrl
|
|
}
|
|
client := openai.NewClientWithConfig(config)
|
|
embResp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
|
|
Input: []string{question},
|
|
Model: openai.EmbeddingModel(dbCfg.EmbeddingModelName),
|
|
})
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
conn, qdCtx, connErr := newQdrantConn(dbCfg)
|
|
if connErr != nil {
|
|
return ""
|
|
}
|
|
defer conn.Close()
|
|
|
|
ptsClient := qdrant.NewPointsClient(conn)
|
|
searchRes, searchErr := ptsClient.Search(qdCtx, &qdrant.SearchPoints{
|
|
CollectionName: dbCfg.QdrantCollection,
|
|
Vector: embResp.Data[0].Embedding,
|
|
Limit: 3,
|
|
WithPayload: &qdrant.WithPayloadSelector{
|
|
SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true},
|
|
},
|
|
})
|
|
if searchErr != nil {
|
|
return ""
|
|
}
|
|
|
|
var b strings.Builder
|
|
for _, pt := range searchRes.GetResult() {
|
|
if txt, ok := pt.GetPayload()["full_text"]; ok {
|
|
b.WriteString(txt.GetStringValue())
|
|
b.WriteString("\n")
|
|
}
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
func SaveHistory(ctx context.Context, svcCtx *svc.ServiceContext, userID, question, answer string) {
|
|
if userID == "" || question == "" || answer == "" {
|
|
return
|
|
}
|
|
_, _ = svcCtx.PlantRpc.SaveAiChatHistory(ctx, &plantPb.SaveAiChatHistoryReq{
|
|
UserId: userID, Question: question, Answer: answer,
|
|
})
|
|
}
|
|
|
|
func ChatCompletion(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string) (string, error) {
|
|
dbCfg, err := getActiveAiConfig(svcCtx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if err := ensureQuota(ctx, svcCtx, userID, dbCfg); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。"
|
|
if ctxText := retrieveRAGContext(ctx, svcCtx, dbCfg, question); ctxText != "" {
|
|
systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------"
|
|
}
|
|
|
|
config := openai.DefaultConfig(dbCfg.ChatApiKey)
|
|
if dbCfg.ChatApiUrl != "" {
|
|
config.BaseURL = dbCfg.ChatApiUrl
|
|
}
|
|
client := openai.NewClientWithConfig(config)
|
|
resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
|
|
Model: chatModel(dbCfg),
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
|
|
{Role: openai.ChatMessageRoleUser, Content: question},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(resp.Choices) == 0 || resp.Choices[0].Message.Content == "" {
|
|
return "", errors.New("AI 响应为空")
|
|
}
|
|
answer := resp.Choices[0].Message.Content
|
|
SaveHistory(ctx, svcCtx, userID, question, answer)
|
|
return answer, nil
|
|
}
|
|
|
|
func StreamChat(ctx context.Context, svcCtx *svc.ServiceContext, userID, question string, w io.Writer) error {
|
|
dbCfg, err := getActiveAiConfig(svcCtx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := ensureQuota(ctx, svcCtx, userID, dbCfg); err != nil {
|
|
return err
|
|
}
|
|
|
|
systemPrompt := "你是一个专业的植物百科助手。回答规则:基于知识库信息回答,不够则结合通用知识;不要使用 Markdown;用纯文本分段;回答简洁专业、条理清晰。"
|
|
if ctxText := retrieveRAGContext(ctx, svcCtx, dbCfg, question); ctxText != "" {
|
|
systemPrompt += "\n--- 知识库 ---\n" + ctxText + "\n--------------"
|
|
}
|
|
|
|
config := openai.DefaultConfig(dbCfg.ChatApiKey)
|
|
if dbCfg.ChatApiUrl != "" {
|
|
config.BaseURL = dbCfg.ChatApiUrl
|
|
}
|
|
client := openai.NewClientWithConfig(config)
|
|
stream, err := client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
|
Model: chatModel(dbCfg),
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
|
|
{Role: openai.ChatMessageRoleUser, Content: question},
|
|
},
|
|
Stream: true,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stream.Close()
|
|
|
|
var answer strings.Builder
|
|
for {
|
|
resp, recvErr := stream.Recv()
|
|
if errors.Is(recvErr, io.EOF) {
|
|
break
|
|
}
|
|
if recvErr != nil {
|
|
return recvErr
|
|
}
|
|
if len(resp.Choices) > 0 {
|
|
content := resp.Choices[0].Delta.Content
|
|
if content != "" {
|
|
_, _ = fmt.Fprintf(w, "data: %s\n\n", content)
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
answer.WriteString(content)
|
|
}
|
|
}
|
|
}
|
|
|
|
SaveHistory(ctx, svcCtx, userID, question, answer.String())
|
|
return nil
|
|
}
|
|
|
|
func ensureQuota(ctx context.Context, svcCtx *svc.ServiceContext, userID string, dbCfg *plantModel.SysAiConfig) error {
|
|
if userID == "" {
|
|
return nil
|
|
}
|
|
quota, err := svcCtx.PlantRpc.GetAiChatQuota(ctx, &plantPb.GetProfileReq{UserId: userID})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
limit := int64(dbCfg.DailyQueryLimit)
|
|
if limit > 0 && quota.Remaining <= 0 {
|
|
return fmt.Errorf("今日问答次数已达上限(%d次),明天再来吧", limit)
|
|
}
|
|
return nil
|
|
}
|