161 lines
4.8 KiB
Go
161 lines
4.8 KiB
Go
package logic
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/md5"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
plantModel "sundynix-micro-go/app/plant/model"
|
|
"sundynix-micro-go/app/plant/rpc/internal/config"
|
|
)
|
|
|
|
// wikiVectorID derives the Qdrant point ID used for a wiki entry.
//
// The ID is the lowercase hex encoding of the MD5 digest of the fixed
// namespace prefix "sundynix-plant-wiki:" followed by wikiID, so it is always
// a 32-character string and the same wikiID always yields the same ID.
func wikiVectorID(wikiID string) string {
	const namespace = "sundynix-plant-wiki:"
	digest := md5.Sum([]byte(namespace + wikiID))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
func buildWikiVectorText(w plantModel.Wiki) string {
|
|
return fmt.Sprintf("植物名字:%s. 拉丁名:%s. 别名:%s. 科属:%s. 分布区域:%s. 生命周期:%s. 生长习性:%s. 繁殖方法:%s. 病虫害:%s. 光照强度:%s. 光照类型:%s. 最佳温度:%s. 茎:%s. 叶型:%s. 叶色:%s. 叶形:%s. 高度:%d厘米. 开花期:%s. 花色:%s. 花形:%s. 花直径:%d厘米. 果实:%s.",
|
|
w.Name, w.LatinName, w.Aliases, w.Genus, w.DistributionArea, w.LifeCycle,
|
|
w.GrowthHabit, w.ReproductionMethod, w.PestsDiseases, w.LightIntensity,
|
|
w.LightType, w.OptimalTempPeriod, w.Stem, w.FoliageType, w.FoliageColor,
|
|
w.FoliageShape, w.Height, w.FloweringPeriod, w.FloweringColor,
|
|
w.FloweringShape, w.FlowerDiameter, w.Fruit)
|
|
}
|
|
|
|
func embeddingModel(c config.Config) string {
|
|
if c.Ai.EmbeddingModelName != "" {
|
|
return c.Ai.EmbeddingModelName
|
|
}
|
|
return "text-embedding-3-small"
|
|
}
|
|
|
|
func createEmbedding(ctx context.Context, c config.Config, text string) ([]float32, error) {
|
|
body, _ := json.Marshal(map[string]interface{}{
|
|
"model": embeddingModel(c),
|
|
"input": text,
|
|
})
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Ai.EmbeddingApiUrl, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+c.Ai.EmbeddingApiKey)
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
|
return nil, fmt.Errorf("embedding 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw)))
|
|
}
|
|
var parsed struct {
|
|
Data []struct {
|
|
Embedding []float32 `json:"embedding"`
|
|
} `json:"data"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
|
return nil, err
|
|
}
|
|
if len(parsed.Data) == 0 || len(parsed.Data[0].Embedding) == 0 {
|
|
return nil, errors.New("embedding 响应为空")
|
|
}
|
|
return parsed.Data[0].Embedding, nil
|
|
}
|
|
|
|
func qdrantURL(c config.Config, path string) string {
|
|
return strings.TrimRight(c.Ai.QdrantUrl, "/") + path
|
|
}
|
|
|
|
func doQdrant(ctx context.Context, c config.Config, method, path string, body interface{}) error {
|
|
var reader io.Reader
|
|
if body != nil {
|
|
raw, _ := json.Marshal(body)
|
|
reader = bytes.NewReader(raw)
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, method, qdrantURL(c, path), reader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if c.Ai.QdrantApiKey != "" {
|
|
req.Header.Set("api-key", c.Ai.QdrantApiKey)
|
|
}
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
|
return fmt.Errorf("qdrant 请求失败: %s %s", resp.Status, strings.TrimSpace(string(raw)))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ensureQdrantCollection(ctx context.Context, c config.Config, dim int) error {
|
|
getReq, err := http.NewRequestWithContext(ctx, http.MethodGet, qdrantURL(c, "/collections/"+c.Ai.QdrantCollection), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if c.Ai.QdrantApiKey != "" {
|
|
getReq.Header.Set("api-key", c.Ai.QdrantApiKey)
|
|
}
|
|
if resp, err := http.DefaultClient.Do(getReq); err == nil {
|
|
_ = resp.Body.Close()
|
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
|
return nil
|
|
}
|
|
}
|
|
if dim <= 0 {
|
|
dim = c.Ai.VectorDimension
|
|
}
|
|
if dim <= 0 {
|
|
dim = 1536
|
|
}
|
|
return doQdrant(ctx, c, http.MethodPut, "/collections/"+c.Ai.QdrantCollection, map[string]interface{}{
|
|
"vectors": map[string]interface{}{
|
|
"size": dim,
|
|
"distance": "Cosine",
|
|
},
|
|
})
|
|
}
|
|
|
|
func upsertWikiVector(ctx context.Context, c config.Config, w plantModel.Wiki) error {
|
|
text := buildWikiVectorText(w)
|
|
vector, err := createEmbedding(ctx, c, text)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := ensureQdrantCollection(ctx, c, len(vector)); err != nil {
|
|
return err
|
|
}
|
|
return doQdrant(ctx, c, http.MethodPut, "/collections/"+c.Ai.QdrantCollection+"/points?wait=true", map[string]interface{}{
|
|
"points": []map[string]interface{}{
|
|
{
|
|
"id": wikiVectorID(w.ID),
|
|
"vector": vector,
|
|
"payload": map[string]interface{}{
|
|
"wiki_id": w.ID,
|
|
"name": w.Name,
|
|
"full_text": text,
|
|
},
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
func deleteWikiVector(ctx context.Context, c config.Config, wikiID string) error {
|
|
return doQdrant(ctx, c, http.MethodPost, "/collections/"+c.Ai.QdrantCollection+"/points/delete?wait=true", map[string]interface{}{
|
|
"points": []string{wikiVectorID(wikiID)},
|
|
})
|
|
}
|