feat: 百科知识库存入向量
This commit is contained in:
@@ -0,0 +1,395 @@
|
||||
package plant
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
qdrant "github.com/qdrant/go-client/qdrant"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"sundynix-go/global"
|
||||
plantModel "sundynix-go/model/plant"
|
||||
"sundynix-go/model/system"
|
||||
systemService "sundynix-go/service/system"
|
||||
)
|
||||
|
||||
type AiRagService struct{}
|
||||
|
||||
var sysAiConfigService = systemService.SysAiConfigService{}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
// OpenAI 客户端构建
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
|
||||
func getChatClient(cfg *system.SysAiConfig) *openai.Client {
|
||||
config := openai.DefaultConfig(cfg.ChatApiKey)
|
||||
if cfg.ChatApiUrl != "" {
|
||||
config.BaseURL = cfg.ChatApiUrl
|
||||
}
|
||||
return openai.NewClientWithConfig(config)
|
||||
}
|
||||
|
||||
func getEmbeddingClient(cfg *system.SysAiConfig) *openai.Client {
|
||||
config := openai.DefaultConfig(cfg.EmbeddingApiKey)
|
||||
if cfg.EmbeddingApiUrl != "" {
|
||||
config.BaseURL = cfg.EmbeddingApiUrl
|
||||
}
|
||||
return openai.NewClientWithConfig(config)
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
// Qdrant gRPC 连接
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
|
||||
func newQdrantConn(cfg *system.SysAiConfig) (*grpc.ClientConn, context.Context, error) {
|
||||
addr := strings.TrimPrefix(cfg.QdrantUrl, "http://")
|
||||
addr = strings.TrimPrefix(addr, "https://")
|
||||
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
if cfg.QdrantApiKey != "" {
|
||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("api-key", cfg.QdrantApiKey))
|
||||
}
|
||||
return conn, ctx, nil
|
||||
}
|
||||
|
||||
// EnsureCollection 确保 Collection 存在,不存在则创建
|
||||
func EnsureCollection(cfg *system.SysAiConfig) error {
|
||||
conn, ctx, err := newQdrantConn(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
dim := uint64(cfg.VectorDimension)
|
||||
if dim == 0 {
|
||||
dim = 104
|
||||
}
|
||||
collClient := qdrant.NewCollectionsClient(conn)
|
||||
if _, getErr := collClient.Get(ctx, &qdrant.GetCollectionInfoRequest{CollectionName: cfg.QdrantCollection}); getErr == nil {
|
||||
return nil // 已存在
|
||||
}
|
||||
_, err = collClient.Create(ctx, &qdrant.CreateCollection{
|
||||
CollectionName: cfg.QdrantCollection,
|
||||
VectorsConfig: &qdrant.VectorsConfig{
|
||||
Config: &qdrant.VectorsConfig_Params{
|
||||
Params: &qdrant.VectorParams{Size: dim, Distance: qdrant.Distance_Cosine},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("qdrant create collection failed: %w", err)
|
||||
}
|
||||
global.Logger.Info("Qdrant collection created", zap.String("collection", cfg.QdrantCollection))
|
||||
return nil
|
||||
}
|
||||
|
||||
// wikiID → Qdrant point UUID(确保幂等)
|
||||
func wikiToQdrantID(wikiId string) string {
|
||||
return uuid.NewMD5(uuid.NameSpaceOID, []byte(wikiId)).String()
|
||||
}
|
||||
|
||||
// buildWikiText 拼接用于向量化的文本语料
|
||||
func buildWikiText(w plantModel.Wiki) string {
|
||||
return fmt.Sprintf(
|
||||
"植物名字:%s. 拉丁名:%s. 科属:%s. 生命周期:%s. 生长习性:%s. 病虫害:%s. 光照类型:%s. 最佳温度:%s.",
|
||||
w.Name, w.LatinName, w.Genus, w.LifeCycle, w.GrowthHabit,
|
||||
w.PestsDiseases, w.LightType, w.OptimalTempPeriod,
|
||||
)
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
// SyncSingleWikiAsync 异步同步单条百科到 Qdrant(新增/更新时调用)
|
||||
// 同步成功后将 is_vector_synced 置为 1
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
func (s *AiRagService) SyncSingleWikiAsync(wikiId string) {
|
||||
go func() {
|
||||
if err := s.syncSingleWiki(wikiId); err != nil {
|
||||
global.Logger.Error("Async sync wiki to Qdrant failed", zap.String("wiki_id", wikiId), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SyncSingleWiki 同步同步单条百科到 Qdrant(用于API直接调用)
|
||||
func (s *AiRagService) SyncSingleWiki(wikiId string) error {
|
||||
return s.syncSingleWiki(wikiId)
|
||||
}
|
||||
|
||||
func (s *AiRagService) syncSingleWiki(wikiId string) error {
|
||||
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("no active ai config: %w", err)
|
||||
}
|
||||
if err = EnsureCollection(cfg); err != nil {
|
||||
global.Logger.Warn("EnsureCollection warn", zap.Error(err))
|
||||
}
|
||||
|
||||
var w plantModel.Wiki
|
||||
if err = global.DB.Where("id = ?", wikiId).First(&w).Error; err != nil {
|
||||
return fmt.Errorf("wiki not found: %w", err)
|
||||
}
|
||||
|
||||
text := buildWikiText(w)
|
||||
embClient := getEmbeddingClient(cfg)
|
||||
embResp, err := embClient.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
|
||||
Input: []string{text},
|
||||
Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("embedding failed: %w", err)
|
||||
}
|
||||
|
||||
conn, qdCtx, err := newQdrantConn(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
ptsClient := qdrant.NewPointsClient(conn)
|
||||
|
||||
_, err = ptsClient.Upsert(qdCtx, &qdrant.UpsertPoints{
|
||||
CollectionName: cfg.QdrantCollection,
|
||||
Points: []*qdrant.PointStruct{{
|
||||
Id: qdrant.NewID(wikiToQdrantID(wikiId)),
|
||||
Vectors: qdrant.NewVectors(embResp.Data[0].Embedding...),
|
||||
Payload: map[string]*qdrant.Value{
|
||||
"wiki_id": qdrant.NewValueString(w.Id),
|
||||
"name": qdrant.NewValueString(w.Name),
|
||||
"full_text": qdrant.NewValueString(text),
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("qdrant upsert failed: %w", err)
|
||||
}
|
||||
|
||||
// 更新同步状态
|
||||
_ = global.DB.Model(&plantModel.Wiki{}).Where("id = ?", wikiId).Update("is_vector_synced", 1).Error
|
||||
global.Logger.Info("Wiki synced to Qdrant", zap.String("wiki_id", wikiId))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
// DeleteFromQdrant 从 Qdrant 删除单条植物的向量点位
|
||||
// 删除成功后将 is_vector_synced 置为 0
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
func (s *AiRagService) DeleteFromQdrant(wikiId string) error {
|
||||
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("no active ai config: %w", err)
|
||||
}
|
||||
conn, qdCtx, err := newQdrantConn(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
ptsClient := qdrant.NewPointsClient(conn)
|
||||
|
||||
qID := wikiToQdrantID(wikiId)
|
||||
_, err = ptsClient.Delete(qdCtx, &qdrant.DeletePoints{
|
||||
CollectionName: cfg.QdrantCollection,
|
||||
Points: &qdrant.PointsSelector{
|
||||
PointsSelectorOneOf: &qdrant.PointsSelector_Points{
|
||||
Points: &qdrant.PointsIdsList{
|
||||
Ids: []*qdrant.PointId{qdrant.NewID(qID)},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("qdrant delete failed: %w", err)
|
||||
}
|
||||
_ = global.DB.Model(&plantModel.Wiki{}).Where("id = ?", wikiId).Update("is_vector_synced", 0).Error
|
||||
global.Logger.Info("Wiki deleted from Qdrant", zap.String("wiki_id", wikiId))
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteFromQdrantBatch 批量从 Qdrant 删除(用于批量删除百科时)
|
||||
func (s *AiRagService) DeleteFromQdrantBatch(wikiIds []string) {
|
||||
go func() {
|
||||
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
||||
if err != nil {
|
||||
global.Logger.Warn("No active ai config for batch qdrant delete", zap.Error(err))
|
||||
return
|
||||
}
|
||||
conn, qdCtx, err := newQdrantConn(cfg)
|
||||
if err != nil {
|
||||
global.Logger.Warn("Qdrant connect failed for batch delete", zap.Error(err))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
ptsClient := qdrant.NewPointsClient(conn)
|
||||
|
||||
var ids []*qdrant.PointId
|
||||
for _, wid := range wikiIds {
|
||||
ids = append(ids, qdrant.NewID(wikiToQdrantID(wid)))
|
||||
}
|
||||
_, err = ptsClient.Delete(qdCtx, &qdrant.DeletePoints{
|
||||
CollectionName: cfg.QdrantCollection,
|
||||
Points: &qdrant.PointsSelector{
|
||||
PointsSelectorOneOf: &qdrant.PointsSelector_Points{
|
||||
Points: &qdrant.PointsIdsList{Ids: ids},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
global.Logger.Error("Qdrant batch delete failed", zap.Error(err))
|
||||
} else {
|
||||
global.Logger.Info("Qdrant batch delete done", zap.Int("count", len(wikiIds)))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
// SyncWikiToQdrant 全量同步(后台操作/手动触发)
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
func (s *AiRagService) SyncWikiToQdrant() error {
|
||||
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = EnsureCollection(cfg); err != nil {
|
||||
global.Logger.Warn("EnsureCollection failed, continuing", zap.Error(err))
|
||||
}
|
||||
|
||||
var wikis []plantModel.Wiki
|
||||
if err = global.DB.Find(&wikis).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embClient := getEmbeddingClient(cfg)
|
||||
conn, qdCtx, err := newQdrantConn(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
ptsClient := qdrant.NewPointsClient(conn)
|
||||
|
||||
var successIds []string
|
||||
for _, w := range wikis {
|
||||
text := buildWikiText(w)
|
||||
embResp, embErr := embClient.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
|
||||
Input: []string{text},
|
||||
Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
|
||||
})
|
||||
if embErr != nil {
|
||||
global.Logger.Error("Embedding failed", zap.String("wiki_id", w.Id), zap.Error(embErr))
|
||||
continue
|
||||
}
|
||||
_, upsertErr := ptsClient.Upsert(qdCtx, &qdrant.UpsertPoints{
|
||||
CollectionName: cfg.QdrantCollection,
|
||||
Points: []*qdrant.PointStruct{{
|
||||
Id: qdrant.NewID(wikiToQdrantID(w.Id)),
|
||||
Vectors: qdrant.NewVectors(embResp.Data[0].Embedding...),
|
||||
Payload: map[string]*qdrant.Value{
|
||||
"wiki_id": qdrant.NewValueString(w.Id),
|
||||
"name": qdrant.NewValueString(w.Name),
|
||||
"full_text": qdrant.NewValueString(text),
|
||||
},
|
||||
}},
|
||||
})
|
||||
if upsertErr != nil {
|
||||
global.Logger.Error("Qdrant upsert failed", zap.String("wiki_id", w.Id), zap.Error(upsertErr))
|
||||
} else {
|
||||
successIds = append(successIds, w.Id)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量更新同步状态
|
||||
if len(successIds) > 0 {
|
||||
_ = global.DB.Model(&plantModel.Wiki{}).Where("id IN ?", successIds).Update("is_vector_synced", 1).Error
|
||||
}
|
||||
global.Logger.Info("SyncWikiToQdrant done", zap.Int("total", len(wikis)), zap.Int("success", len(successIds)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
// PlantChatStreamRAG 向量检索 + 大模型流式对话
|
||||
// ──────────────────────────────────────────────────────────────
|
||||
func (s *AiRagService) PlantChatStreamRAG(ctx context.Context, userQuery string, onData func(chunk string) error) error {
|
||||
cfg, err := sysAiConfigService.GetActiveAiConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embClient := getEmbeddingClient(cfg)
|
||||
chatClient := getChatClient(cfg)
|
||||
|
||||
embResp, err := embClient.CreateEmbeddings(ctx, openai.EmbeddingRequest{
|
||||
Input: []string{userQuery},
|
||||
Model: openai.EmbeddingModel(cfg.EmbeddingModelName),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("向量化查询失败: %w", err)
|
||||
}
|
||||
|
||||
conn, qdCtx, connErr := newQdrantConn(cfg)
|
||||
var contextText string
|
||||
if connErr == nil {
|
||||
defer conn.Close()
|
||||
ptsClient := qdrant.NewPointsClient(conn)
|
||||
limit := uint64(3)
|
||||
searchRes, searchErr := ptsClient.Search(qdCtx, &qdrant.SearchPoints{
|
||||
CollectionName: cfg.QdrantCollection,
|
||||
Vector: embResp.Data[0].Embedding,
|
||||
Limit: limit,
|
||||
WithPayload: &qdrant.WithPayloadSelector{
|
||||
SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true},
|
||||
},
|
||||
})
|
||||
if searchErr != nil {
|
||||
global.Logger.Warn("Qdrant search failed, using empty context", zap.Error(searchErr))
|
||||
} else {
|
||||
for _, pt := range searchRes.GetResult() {
|
||||
if txt, ok := pt.GetPayload()["full_text"]; ok {
|
||||
contextText += txt.GetStringValue() + "\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
global.Logger.Warn("Qdrant connect failed, skipping RAG", zap.Error(connErr))
|
||||
}
|
||||
|
||||
systemPrompt := "你是一个专业的植物百科助手,请基于以下知识库信息回答用户问题。如果知识库无相关信息,结合你的通用知识作答。\n"
|
||||
if contextText != "" {
|
||||
systemPrompt += "--- 知识库 ---\n" + contextText + "\n--------------\n"
|
||||
}
|
||||
|
||||
stream, err := chatClient.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
||||
Model: cfg.ChatModelName,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
|
||||
{Role: openai.ChatMessageRoleUser, Content: userQuery},
|
||||
},
|
||||
Stream: true,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("大模型调用失败: %w", err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
for {
|
||||
resp, recvErr := stream.Recv()
|
||||
if errors.Is(recvErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if recvErr != nil {
|
||||
return recvErr
|
||||
}
|
||||
if len(resp.Choices) > 0 {
|
||||
if content := resp.Choices[0].Delta.Content; content != "" {
|
||||
if writeErr := onData(content); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -12,4 +12,5 @@ type ServiceGroup struct {
|
||||
UserProfileService
|
||||
CallbackService
|
||||
ExchangeService
|
||||
AiRagService
|
||||
}
|
||||
|
||||
+17
-2
@@ -11,13 +11,15 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var aiRagService = AiRagService{}
|
||||
|
||||
type WikiService struct{}
|
||||
|
||||
var WikiServiceApp = new(WikiClassService)
|
||||
|
||||
// CreateWiki 创建百科
|
||||
func (s *WikiService) CreateWiki(req plantReq.CreateWiki) error {
|
||||
return global.DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := global.DB.Transaction(func(tx *gorm.DB) error {
|
||||
//1.先模糊查询name是否存在 如果存在 则返回错误
|
||||
if !errors.Is(tx.Where("name like ?", "%"+req.Name+"%").First(&plant.Wiki{}).Error, gorm.ErrRecordNotFound) {
|
||||
return errors.New("植物已经存在")
|
||||
@@ -117,6 +119,14 @@ func (s *WikiService) CreateWiki(req plantReq.CreateWiki) error {
|
||||
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
// 异步同步到 Qdrant(事务提交后,wiki.Id 已可用)
|
||||
var created plant.Wiki
|
||||
if dbErr := global.DB.Where("name = ?", req.Name).First(&created).Error; dbErr == nil {
|
||||
aiRagService.SyncSingleWikiAsync(created.Id)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateWiki 修改百科
|
||||
@@ -288,7 +298,7 @@ func (s *WikiService) UploadImg(req common.UploadOss) error {
|
||||
|
||||
// DeleteWiki 删除百科
|
||||
func (s *WikiService) DeleteWiki(req common.IdsReq) error {
|
||||
return global.DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := global.DB.Transaction(func(tx *gorm.DB) error {
|
||||
var imgIds []string
|
||||
tx.Table("sundynix_wiki_oss").Where("wiki_id IN ?", req.Ids).Pluck("oss_id", &imgIds)
|
||||
// 3. 物理删除图片记录本身
|
||||
@@ -311,4 +321,9 @@ func (s *WikiService) DeleteWiki(req common.IdsReq) error {
|
||||
//删除百科本身
|
||||
return tx.Unscoped().Where("id IN ?", req.Ids).Delete(&plant.Wiki{}).Error
|
||||
})
|
||||
if err == nil {
|
||||
// 异步清理 Qdrant 向量点位
|
||||
aiRagService.DeleteFromQdrantBatch(req.Ids)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user