Files
2026-04-23 11:15:58 +08:00

396 lines
14 KiB
Go

package plant
import (
"context"
"errors"
"fmt"
"io"
"strings"
"github.com/google/uuid"
qdrant "github.com/qdrant/go-client/qdrant"
openai "github.com/sashabaranov/go-openai"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"sundynix-go/global"
plantModel "sundynix-go/model/plant"
"sundynix-go/model/system"
systemService "sundynix-go/service/system"
)
// AiRagService implements retrieval-augmented generation (RAG) for the plant
// wiki: it keeps wiki entries vectorized in Qdrant and answers user questions
// by retrieving similar entries and streaming a chat-model reply.
type AiRagService struct{}

var (
	// sysAiConfigService provides the currently active AI configuration
	// (model endpoints, API keys, Qdrant settings) via GetActiveAiConfig.
	sysAiConfigService = systemService.SysAiConfigService{}
)
// ──────────────────────────────────────────────────────────────
// OpenAI client construction
// ──────────────────────────────────────────────────────────────

// getChatClient builds an OpenAI-compatible client for chat completions.
// It authenticates with cfg.ChatApiKey, and cfg.ChatApiUrl overrides the
// library's default base URL when it is non-empty.
func getChatClient(cfg *system.SysAiConfig) *openai.Client {
	clientCfg := openai.DefaultConfig(cfg.ChatApiKey)
	if baseURL := cfg.ChatApiUrl; baseURL != "" {
		clientCfg.BaseURL = baseURL
	}
	return openai.NewClientWithConfig(clientCfg)
}
// getEmbeddingClient builds an OpenAI-compatible client for the embeddings
// endpoint. It uses cfg.EmbeddingApiKey and, when non-empty,
// cfg.EmbeddingApiUrl as the base URL. These settings are independent of the
// chat settings, so embeddings may be served by a different endpoint.
func getEmbeddingClient(cfg *system.SysAiConfig) *openai.Client {
	oc := openai.DefaultConfig(cfg.EmbeddingApiKey)
	if cfg.EmbeddingApiUrl == "" {
		return openai.NewClientWithConfig(oc)
	}
	oc.BaseURL = cfg.EmbeddingApiUrl
	return openai.NewClientWithConfig(oc)
}
// ──────────────────────────────────────────────────────────────
// Qdrant gRPC connection
// ──────────────────────────────────────────────────────────────

// newQdrantConn creates a gRPC client connection to the Qdrant instance at
// cfg.QdrantUrl. It returns the connection together with a context that
// carries the "api-key" metadata header when cfg.QdrantApiKey is set.
//
// Surrounding whitespace, a leading "http://" or "https://" scheme and any
// trailing slashes are stripped from the URL, so a value such as
// "http://host:6334/" becomes the target "host:6334". An empty address is
// rejected with an error. The caller must Close the returned connection.
//
// NOTE(review): the channel always uses insecure (plaintext) transport
// credentials, even when the configured URL says "https://", so the API key
// travels unencrypted. Confirm this is intended for the deployment.
func newQdrantConn(cfg *system.SysAiConfig) (*grpc.ClientConn, context.Context, error) {
	addr := strings.TrimSpace(cfg.QdrantUrl)
	addr = strings.TrimPrefix(addr, "http://")
	addr = strings.TrimPrefix(addr, "https://")
	// A trailing slash would otherwise end up in the port part of the target.
	addr = strings.TrimRight(addr, "/")
	if addr == "" {
		return nil, nil, errors.New("qdrant grpc dial failed: empty qdrant url")
	}

	conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err)
	}

	ctx := context.Background()
	if cfg.QdrantApiKey != "" {
		ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("api-key", cfg.QdrantApiKey))
	}
	return conn, ctx, nil
}
// EnsureCollection makes sure the Qdrant collection named by
// cfg.QdrantCollection exists, creating it with cosine distance if it does
// not. The vector size is cfg.VectorDimension, or defaultVectorDimension when
// that is zero.
//
// Any error from the existence check is treated as "collection missing", and
// a create is attempted. If the create fails but the collection exists on a
// second check (for example, a concurrent caller created it in between), nil
// is returned.
func EnsureCollection(cfg *system.SysAiConfig) error {
	// NOTE(review): 104 is an unusual embedding size (1024 is far more
	// common). Confirm it matches the configured embedding model; vectors of a
	// different size cannot be upserted into a collection created with it.
	const defaultVectorDimension = 104

	conn, ctx, err := newQdrantConn(cfg)
	if err != nil {
		return err
	}
	defer conn.Close()

	dim := uint64(cfg.VectorDimension)
	if dim == 0 {
		dim = defaultVectorDimension
	}

	collClient := qdrant.NewCollectionsClient(conn)
	getReq := &qdrant.GetCollectionInfoRequest{CollectionName: cfg.QdrantCollection}
	_, getErr := collClient.Get(ctx, getReq)
	if getErr == nil {
		return nil // already exists
	}

	_, err = collClient.Create(ctx, &qdrant.CreateCollection{
		CollectionName: cfg.QdrantCollection,
		VectorsConfig: &qdrant.VectorsConfig{
			Config: &qdrant.VectorsConfig_Params{
				Params: &qdrant.VectorParams{Size: dim, Distance: qdrant.Distance_Cosine},
			},
		},
	})
	if err != nil {
		// Several syncs may run concurrently (SyncSingleWikiAsync starts
		// goroutines that all call EnsureCollection). Another caller may have
		// created the collection between our Get and Create.
		if _, recheckErr := collClient.Get(ctx, getReq); recheckErr == nil {
			return nil
		}
		return fmt.Errorf("qdrant create collection failed: %w (existence check: %v)", err, getErr)
	}

	global.Logger.Info("Qdrant collection created", zap.String("collection", cfg.QdrantCollection))
	return nil
}
// wikiToQdrantID maps a wiki ID to a deterministic Qdrant point UUID: a
// name-based MD5 UUID in the OID namespace. The same wiki ID always yields the
// same point ID, which keeps repeated upserts idempotent.
func wikiToQdrantID(wikiId string) string {
	pointID := uuid.NewMD5(uuid.NameSpaceOID, []byte(wikiId))
	return pointID.String()
}
// buildWikiText assembles the text corpus that is embedded for a wiki entry.
// Each field appears under a Chinese label, in this order: name, Latin name,
// genus, life cycle, growth habit, pests/diseases, light type and optimal
// temperature.
func buildWikiText(w plantModel.Wiki) string {
	const tmpl = "植物名字:%s. 拉丁名:%s. 科属:%s. 生命周期:%s. 生长习性:%s. 病虫害:%s. 光照类型:%s. 最佳温度:%s."
	return fmt.Sprintf(tmpl,
		w.Name,
		w.LatinName,
		w.Genus,
		w.LifeCycle,
		w.GrowthHabit,
		w.PestsDiseases,
		w.LightType,
		w.OptimalTempPeriod,
	)
}
// ──────────────────────────────────────────────────────────────
// SyncSingleWikiAsync: background sync of a single wiki entry to Qdrant
// (called on create/update).
// ──────────────────────────────────────────────────────────────

// SyncSingleWikiAsync syncs the wiki entry wikiId to Qdrant in a new goroutine
// and returns immediately. On success, syncSingleWiki sets is_vector_synced to
// 1. On failure, the error is only logged.
func (s *AiRagService) SyncSingleWikiAsync(wikiId string) {
	run := func() {
		err := s.syncSingleWiki(wikiId)
		if err == nil {
			return
		}
		global.Logger.Error("Async sync wiki to Qdrant failed", zap.String("wiki_id", wikiId), zap.Error(err))
	}
	go run()
}
// SyncSingleWiki synchronously syncs a single wiki entry to Qdrant and returns
// any error to the caller (intended for direct API calls).
func (s *AiRagService) SyncSingleWiki(wikiId string) error {
	if err := s.syncSingleWiki(wikiId); err != nil {
		return err
	}
	return nil
}
// syncSingleWiki embeds one wiki entry and upserts it into Qdrant under its
// deterministic point ID (wikiToQdrantID). It then marks the row as synced
// (is_vector_synced = 1).
//
// A failure of EnsureCollection is only logged, and the upsert is still
// attempted. A failure to update is_vector_synced is logged but not returned,
// because the vector has already been written.
func (s *AiRagService) syncSingleWiki(wikiId string) error {
	cfg, err := sysAiConfigService.GetActiveAiConfig()
	if err != nil {
		return fmt.Errorf("no active ai config: %w", err)
	}
	if err = EnsureCollection(cfg); err != nil {
		global.Logger.Warn("EnsureCollection warn", zap.Error(err))
	}

	var w plantModel.Wiki
	if err = global.DB.Where("id = ?", wikiId).First(&w).Error; err != nil {
		return fmt.Errorf("wiki not found: %w", err)
	}

	text := buildWikiText(w)
	embResp, err := getEmbeddingClient(cfg).CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
		Input: []string{text},
		Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
	})
	if err != nil {
		return fmt.Errorf("embedding failed: %w", err)
	}
	// Guard the Data[0] access below. This function runs inside the
	// SyncSingleWikiAsync goroutine, where an index panic would crash the
	// whole process.
	if len(embResp.Data) == 0 {
		return errors.New("embedding failed: empty embedding response")
	}

	conn, qdCtx, err := newQdrantConn(cfg)
	if err != nil {
		return err
	}
	defer conn.Close()

	_, err = qdrant.NewPointsClient(conn).Upsert(qdCtx, &qdrant.UpsertPoints{
		CollectionName: cfg.QdrantCollection,
		Points: []*qdrant.PointStruct{{
			Id:      qdrant.NewID(wikiToQdrantID(wikiId)),
			Vectors: qdrant.NewVectors(embResp.Data[0].Embedding...),
			Payload: map[string]*qdrant.Value{
				"wiki_id":   qdrant.NewValueString(w.Id),
				"name":      qdrant.NewValueString(w.Name),
				"full_text": qdrant.NewValueString(text),
			},
		}},
	})
	if err != nil {
		return fmt.Errorf("qdrant upsert failed: %w", err)
	}

	// Record the sync status. Best-effort, but make failures visible.
	if err = global.DB.Model(&plantModel.Wiki{}).Where("id = ?", wikiId).Update("is_vector_synced", 1).Error; err != nil {
		global.Logger.Warn("Update is_vector_synced failed", zap.String("wiki_id", wikiId), zap.Error(err))
	}
	global.Logger.Info("Wiki synced to Qdrant", zap.String("wiki_id", wikiId))
	return nil
}
// ──────────────────────────────────────────────────────────────
// DeleteFromQdrant: remove a single plant's vector point from Qdrant.
// ──────────────────────────────────────────────────────────────

// DeleteFromQdrant deletes the Qdrant point derived from wikiId (see
// wikiToQdrantID). On success, it resets is_vector_synced to 0 for that wiki
// row. A failure of that status update is logged but not returned, because
// the point has already been removed.
func (s *AiRagService) DeleteFromQdrant(wikiId string) error {
	cfg, err := sysAiConfigService.GetActiveAiConfig()
	if err != nil {
		return fmt.Errorf("no active ai config: %w", err)
	}

	conn, qdCtx, err := newQdrantConn(cfg)
	if err != nil {
		return err
	}
	defer conn.Close()

	qID := wikiToQdrantID(wikiId)
	_, err = qdrant.NewPointsClient(conn).Delete(qdCtx, &qdrant.DeletePoints{
		CollectionName: cfg.QdrantCollection,
		Points: &qdrant.PointsSelector{
			PointsSelectorOneOf: &qdrant.PointsSelector_Points{
				Points: &qdrant.PointsIdsList{
					Ids: []*qdrant.PointId{qdrant.NewID(qID)},
				},
			},
		},
	})
	if err != nil {
		return fmt.Errorf("qdrant delete failed: %w", err)
	}

	if err = global.DB.Model(&plantModel.Wiki{}).Where("id = ?", wikiId).Update("is_vector_synced", 0).Error; err != nil {
		global.Logger.Warn("Reset is_vector_synced failed", zap.String("wiki_id", wikiId), zap.Error(err))
	}
	global.Logger.Info("Wiki deleted from Qdrant", zap.String("wiki_id", wikiId))
	return nil
}
// DeleteFromQdrantBatch removes the Qdrant points for wikiIds (used when wiki
// entries are deleted in bulk). The delete runs in a background goroutine, and
// failures are logged rather than returned. An empty slice is a no-op.
//
// Point IDs are derived before the goroutine starts, so the caller may reuse
// or modify wikiIds as soon as this returns.
func (s *AiRagService) DeleteFromQdrantBatch(wikiIds []string) {
	if len(wikiIds) == 0 {
		return
	}

	ids := make([]*qdrant.PointId, 0, len(wikiIds))
	for _, wid := range wikiIds {
		ids = append(ids, qdrant.NewID(wikiToQdrantID(wid)))
	}
	count := len(ids)

	go func() {
		cfg, err := sysAiConfigService.GetActiveAiConfig()
		if err != nil {
			global.Logger.Warn("No active ai config for batch qdrant delete", zap.Error(err))
			return
		}
		conn, qdCtx, err := newQdrantConn(cfg)
		if err != nil {
			global.Logger.Warn("Qdrant connect failed for batch delete", zap.Error(err))
			return
		}
		defer conn.Close()

		_, err = qdrant.NewPointsClient(conn).Delete(qdCtx, &qdrant.DeletePoints{
			CollectionName: cfg.QdrantCollection,
			Points: &qdrant.PointsSelector{
				PointsSelectorOneOf: &qdrant.PointsSelector_Points{
					Points: &qdrant.PointsIdsList{Ids: ids},
				},
			},
		})
		if err != nil {
			global.Logger.Error("Qdrant batch delete failed", zap.Error(err))
			return
		}
		global.Logger.Info("Qdrant batch delete done", zap.Int("count", count))
	}()
}
// ──────────────────────────────────────────────────────────────
// SyncWikiToQdrant: full sync (background operation / manual trigger).
// ──────────────────────────────────────────────────────────────

// SyncWikiToQdrant re-embeds every wiki row and upserts it into Qdrant. It then
// sets is_vector_synced = 1 for the rows that were written successfully.
//
// Per-row embedding or upsert failures are logged and skipped. The method
// therefore returns nil even when some or all rows failed. Only errors from
// config lookup, the database read and the Qdrant connection are returned.
// An EnsureCollection failure is logged, and the sync continues.
func (s *AiRagService) SyncWikiToQdrant() error {
	cfg, err := sysAiConfigService.GetActiveAiConfig()
	if err != nil {
		return err
	}
	if err = EnsureCollection(cfg); err != nil {
		global.Logger.Warn("EnsureCollection failed, continuing", zap.Error(err))
	}

	var wikis []plantModel.Wiki
	if err = global.DB.Find(&wikis).Error; err != nil {
		return err
	}

	embClient := getEmbeddingClient(cfg)
	conn, qdCtx, err := newQdrantConn(cfg)
	if err != nil {
		return err
	}
	defer conn.Close()
	ptsClient := qdrant.NewPointsClient(conn)

	successIds := make([]string, 0, len(wikis))
	for i := range wikis {
		w := &wikis[i] // index instead of copying each (large) row
		text := buildWikiText(*w)
		embResp, embErr := embClient.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
			Input: []string{text},
			Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
		})
		if embErr != nil {
			global.Logger.Error("Embedding failed", zap.String("wiki_id", w.Id), zap.Error(embErr))
			continue
		}
		// Avoid an index panic on an empty embedding response.
		if len(embResp.Data) == 0 {
			global.Logger.Error("Embedding returned no data", zap.String("wiki_id", w.Id))
			continue
		}

		_, upsertErr := ptsClient.Upsert(qdCtx, &qdrant.UpsertPoints{
			CollectionName: cfg.QdrantCollection,
			Points: []*qdrant.PointStruct{{
				Id:      qdrant.NewID(wikiToQdrantID(w.Id)),
				Vectors: qdrant.NewVectors(embResp.Data[0].Embedding...),
				Payload: map[string]*qdrant.Value{
					"wiki_id":   qdrant.NewValueString(w.Id),
					"name":      qdrant.NewValueString(w.Name),
					"full_text": qdrant.NewValueString(text),
				},
			}},
		})
		if upsertErr != nil {
			global.Logger.Error("Qdrant upsert failed", zap.String("wiki_id", w.Id), zap.Error(upsertErr))
			continue
		}
		successIds = append(successIds, w.Id)
	}

	// Batch-update the sync status. Best-effort, but make failures visible.
	if len(successIds) > 0 {
		if err = global.DB.Model(&plantModel.Wiki{}).Where("id IN ?", successIds).Update("is_vector_synced", 1).Error; err != nil {
			global.Logger.Error("Update is_vector_synced failed", zap.Int("count", len(successIds)), zap.Error(err))
		}
	}
	global.Logger.Info("SyncWikiToQdrant done", zap.Int("total", len(wikis)), zap.Int("success", len(successIds)))
	return nil
}
// ──────────────────────────────────────────────────────────────
// PlantChatStreamRAG: vector retrieval + streaming LLM chat.
// ──────────────────────────────────────────────────────────────

// PlantChatStreamRAG answers userQuery with retrieval-augmented generation.
// It embeds the query and fetches up to three nearest points from Qdrant. The
// points' stored "full_text" payloads are added to the system prompt. It then
// streams the chat model's reply, calling onData for each non-empty content
// delta.
//
// Qdrant problems (connection or search failure) are logged, and the chat
// proceeds without knowledge-base context. Embedding errors, chat errors and
// any error returned by onData are returned. The Qdrant search is bound to
// ctx, so it is abandoned when the caller cancels.
func (s *AiRagService) PlantChatStreamRAG(ctx context.Context, userQuery string, onData func(chunk string) error) error {
	const searchLimit = uint64(3)

	cfg, err := sysAiConfigService.GetActiveAiConfig()
	if err != nil {
		return err
	}
	embClient := getEmbeddingClient(cfg)
	chatClient := getChatClient(cfg)

	embResp, err := embClient.CreateEmbeddings(ctx, openai.EmbeddingRequest{
		Input: []string{userQuery},
		Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
	})
	if err != nil {
		return fmt.Errorf("向量化查询失败: %w", err)
	}
	// Guard the Data[0] access below against an empty response.
	if len(embResp.Data) == 0 {
		return errors.New("向量化查询失败: empty embedding response")
	}

	var contextText strings.Builder
	conn, qdCtx, connErr := newQdrantConn(cfg)
	if connErr == nil {
		defer conn.Close()
		// newQdrantConn's context is rooted in context.Background(). Re-attach
		// its auth metadata to ctx so the search honours caller cancellation.
		searchCtx := ctx
		if md, ok := metadata.FromOutgoingContext(qdCtx); ok {
			searchCtx = metadata.NewOutgoingContext(ctx, md)
		}
		searchRes, searchErr := qdrant.NewPointsClient(conn).Search(searchCtx, &qdrant.SearchPoints{
			CollectionName: cfg.QdrantCollection,
			Vector:         embResp.Data[0].Embedding,
			Limit:          searchLimit,
			WithPayload: &qdrant.WithPayloadSelector{
				SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true},
			},
		})
		if searchErr != nil {
			global.Logger.Warn("Qdrant search failed, using empty context", zap.Error(searchErr))
		} else {
			for _, pt := range searchRes.GetResult() {
				if txt, ok := pt.GetPayload()["full_text"]; ok {
					contextText.WriteString(txt.GetStringValue())
					contextText.WriteString("\n")
				}
			}
		}
	} else {
		global.Logger.Warn("Qdrant connect failed, skipping RAG", zap.Error(connErr))
	}

	systemPrompt := "你是一个专业的植物百科助手。回答规则:1.基于知识库信息回答,不够则结合通用知识。2.严禁使用Markdown语法(不要用#、*、-、```等符号)。3.用纯文本回答,段落之间空一行。4.分类用「一、二、三」或emoji开头,重点内容直接加书名号《》或【】强调。5.回答简洁专业、条理清晰。\n"
	if contextText.Len() > 0 {
		systemPrompt += "--- 知识库 ---\n" + contextText.String() + "\n--------------\n"
	}

	stream, err := chatClient.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
		Model: cfg.ChatModelName,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
			{Role: openai.ChatMessageRoleUser, Content: userQuery},
		},
		Stream: true,
	})
	if err != nil {
		return fmt.Errorf("大模型调用失败: %w", err)
	}
	defer stream.Close()

	for {
		resp, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			return nil
		}
		if recvErr != nil {
			return recvErr
		}
		if len(resp.Choices) == 0 {
			continue
		}
		if content := resp.Choices[0].Delta.Content; content != "" {
			if writeErr := onData(content); writeErr != nil {
				return writeErr
			}
		}
	}
}