159 lines
4.8 KiB
Go
159 lines
4.8 KiB
Go
package logic
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
qdrant "github.com/qdrant/go-client/qdrant"
|
|
openai "github.com/sashabaranov/go-openai"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/metadata"
|
|
"gorm.io/gorm"
|
|
|
|
plantModel "sundynix-micro-go/app/plant/model"
|
|
)
|
|
|
|
func wikiVectorID(wikiID string) string {
|
|
return uuid.NewMD5(uuid.NameSpaceOID, []byte(wikiID)).String()
|
|
}
|
|
|
|
func buildWikiVectorText(w plantModel.Wiki) string {
|
|
return fmt.Sprintf("植物名字:%s. 拉丁名:%s. 别名:%s. 科属:%s. 分布区域:%s. 生命周期:%s. 生长习性:%s. 繁殖方法:%s. 病虫害:%s. 光照强度:%s. 光照类型:%s. 最佳温度:%s. 茎:%s. 叶型:%s. 叶色:%s. 叶形:%s. 高度:%d厘米. 开花期:%s. 花色:%s. 花形:%s. 花直径:%d厘米. 果实:%s.",
|
|
w.Name, w.LatinName, w.Aliases, w.Genus, w.DistributionArea, w.LifeCycle,
|
|
w.GrowthHabit, w.ReproductionMethod, w.PestsDiseases, w.LightIntensity,
|
|
w.LightType, w.OptimalTempPeriod, w.Stem, w.FoliageType, w.FoliageColor,
|
|
w.FoliageShape, w.Height, w.FloweringPeriod, w.FloweringColor,
|
|
w.FloweringShape, w.FlowerDiameter, w.Fruit)
|
|
}
|
|
|
|
func getActiveAiConfig(db *gorm.DB) (*plantModel.SysAiConfig, error) {
|
|
var cfg plantModel.SysAiConfig
|
|
if err := db.Where("is_active = 1").First(&cfg).Error; err != nil {
|
|
return nil, errors.New("数据库未找到已激活的 AI 配置")
|
|
}
|
|
return &cfg, nil
|
|
}
|
|
|
|
func embeddingModel(cfg *plantModel.SysAiConfig) string {
|
|
if cfg.EmbeddingModelName != "" {
|
|
return cfg.EmbeddingModelName
|
|
}
|
|
return "text-embedding-3-small"
|
|
}
|
|
|
|
func createEmbedding(ctx context.Context, cfg *plantModel.SysAiConfig, text string) ([]float32, error) {
|
|
config := openai.DefaultConfig(cfg.EmbeddingApiKey)
|
|
if cfg.EmbeddingApiUrl != "" {
|
|
config.BaseURL = cfg.EmbeddingApiUrl
|
|
}
|
|
client := openai.NewClientWithConfig(config)
|
|
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
|
|
Input: []string{text},
|
|
Model: openai.EmbeddingModel(embeddingModel(cfg)),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(resp.Data) == 0 || len(resp.Data[0].Embedding) == 0 {
|
|
return nil, errors.New("embedding 响应为空")
|
|
}
|
|
return resp.Data[0].Embedding, nil
|
|
}
|
|
|
|
func newQdrantConn(cfg *plantModel.SysAiConfig) (*grpc.ClientConn, context.Context, error) {
|
|
addr := strings.TrimPrefix(cfg.QdrantUrl, "http://")
|
|
addr = strings.TrimPrefix(addr, "https://")
|
|
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("qdrant grpc dial failed: %w", err)
|
|
}
|
|
ctx := context.Background()
|
|
if cfg.QdrantApiKey != "" {
|
|
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("api-key", cfg.QdrantApiKey))
|
|
}
|
|
return conn, ctx, nil
|
|
}
|
|
|
|
func ensureQdrantCollection(cfg *plantModel.SysAiConfig, dim int) error {
|
|
conn, ctx, err := newQdrantConn(cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
if dim <= 0 {
|
|
dim = cfg.VectorDimension
|
|
}
|
|
if dim <= 0 {
|
|
dim = 1536
|
|
}
|
|
collClient := qdrant.NewCollectionsClient(conn)
|
|
if _, getErr := collClient.Get(ctx, &qdrant.GetCollectionInfoRequest{CollectionName: cfg.QdrantCollection}); getErr == nil {
|
|
return nil
|
|
}
|
|
_, err = collClient.Create(ctx, &qdrant.CreateCollection{
|
|
CollectionName: cfg.QdrantCollection,
|
|
VectorsConfig: &qdrant.VectorsConfig{
|
|
Config: &qdrant.VectorsConfig_Params{
|
|
Params: &qdrant.VectorParams{Size: uint64(dim), Distance: qdrant.Distance_Cosine},
|
|
},
|
|
},
|
|
})
|
|
return err
|
|
}
|
|
|
|
func upsertWikiVector(ctx context.Context, cfg *plantModel.SysAiConfig, w plantModel.Wiki) error {
|
|
text := buildWikiVectorText(w)
|
|
vector, err := createEmbedding(ctx, cfg, text)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := ensureQdrantCollection(cfg, len(vector)); err != nil {
|
|
return err
|
|
}
|
|
conn, qdCtx, err := newQdrantConn(cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
ptsClient := qdrant.NewPointsClient(conn)
|
|
|
|
_, err = ptsClient.Upsert(qdCtx, &qdrant.UpsertPoints{
|
|
CollectionName: cfg.QdrantCollection,
|
|
Points: []*qdrant.PointStruct{{
|
|
Id: qdrant.NewID(wikiVectorID(w.ID)),
|
|
Vectors: qdrant.NewVectors(vector...),
|
|
Payload: map[string]*qdrant.Value{
|
|
"wiki_id": qdrant.NewValueString(w.ID),
|
|
"name": qdrant.NewValueString(w.Name),
|
|
"full_text": qdrant.NewValueString(text),
|
|
},
|
|
}},
|
|
})
|
|
return err
|
|
}
|
|
|
|
func deleteWikiVector(ctx context.Context, cfg *plantModel.SysAiConfig, wikiID string) error {
|
|
conn, qdCtx, err := newQdrantConn(cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
ptsClient := qdrant.NewPointsClient(conn)
|
|
|
|
_, err = ptsClient.Delete(qdCtx, &qdrant.DeletePoints{
|
|
CollectionName: cfg.QdrantCollection,
|
|
Points: &qdrant.PointsSelector{
|
|
PointsSelectorOneOf: &qdrant.PointsSelector_Points{
|
|
Points: &qdrant.PointsIdsList{
|
|
Ids: []*qdrant.PointId{qdrant.NewID(wikiVectorID(wikiID))},
|
|
},
|
|
},
|
|
},
|
|
})
|
|
return err
|
|
}
|