package rag import ( "context" "fmt" "sync" "github.com/milvus-io/milvus-sdk-go/v2/client" "github.com/milvus-io/milvus-sdk-go/v2/entity" ) const collection = "sundynix_wiki" // Wiki/知识库向量集合 // milvusStore 封装 Milvus 连接与集合管理(集合按首次写入的向量维度懒建)。 type milvusStore struct { cli client.Client mu sync.Mutex dim int // 已建集合的维度(0=未建) ok bool // 集合是否就绪 } func openMilvus(ctx context.Context, addr string) (*milvusStore, error) { cli, err := client.NewClient(ctx, client.Config{Address: addr}) if err != nil { return nil, err } return &milvusStore{cli: cli}, nil } func (m *milvusStore) close() { if m != nil && m.cli != nil { _ = m.cli.Close() } } // ensure 幂等地按维度 dim 建集合 + 向量索引 + 加载(首次写入时调用)。 func (m *milvusStore) ensure(ctx context.Context, dim int) error { m.mu.Lock() defer m.mu.Unlock() if m.ok && m.dim == dim { return nil } has, err := m.cli.HasCollection(ctx, collection) if err != nil { return err } if !has { schema := entity.NewSchema().WithName(collection).WithDescription("sundynix wiki vectors"). WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)). WithField(entity.NewField().WithName("kb").WithDataType(entity.FieldTypeVarChar).WithMaxLength(64)). WithField(entity.NewField().WithName("text").WithDataType(entity.FieldTypeVarChar).WithMaxLength(8192)). WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(int64(dim))) if err := m.cli.CreateCollection(ctx, schema, 1); err != nil { return fmt.Errorf("create collection: %w", err) } idx, _ := entity.NewIndexAUTOINDEX(entity.COSINE) if err := m.cli.CreateIndex(ctx, collection, "vector", idx, false); err != nil { return fmt.Errorf("create index: %w", err) } } if err := m.cli.LoadCollection(ctx, collection, false); err != nil { return fmt.Errorf("load collection: %w", err) } m.dim, m.ok = dim, true return nil } // insert 写入若干 (kb, text, vector)。 func (m *milvusStore) insert(ctx context.Context, kb string, texts []string, vecs [][]float32) error { if len(vecs) == 0 { return nil } if err := m.ensure(ctx, len(vecs[0])); err != nil { return err } kbs := make([]string, len(texts)) for i := range kbs { kbs[i] = kb } _, err := m.cli.Insert(ctx, collection, "", entity.NewColumnVarChar("kb", kbs), entity.NewColumnVarChar("text", texts), entity.NewColumnFloatVector("vector", len(vecs[0]), vecs), ) if err != nil { return fmt.Errorf("insert: %w", err) } return m.cli.Flush(ctx, collection, false) } // Hit 是一条检索结果。 type Hit struct { Text string Score float32 } // search 用查询向量做 topK 向量检索(可按 kb 过滤)。 func (m *milvusStore) search(ctx context.Context, kb string, qvec []float32, topK int) ([]Hit, error) { if !m.ok { // 集合未建(还没入过库)→ 尝试确保(按查询维度),无则空结果。 if err := m.ensure(ctx, len(qvec)); err != nil { return nil, nil } } expr := "" if kb != "" { expr = fmt.Sprintf("kb == \"%s\"", kb) } sp, _ := entity.NewIndexAUTOINDEXSearchParam(1) results, err := m.cli.Search(ctx, collection, nil, expr, []string{"text"}, []entity.Vector{entity.FloatVector(qvec)}, "vector", entity.COSINE, topK, sp) if err != nil { return nil, fmt.Errorf("search: %w", err) } var hits []Hit for _, r := range results { textCol := r.Fields.GetColumn("text") for i := 0; i < r.ResultCount; i++ { text := "" if textCol != nil { if s, err := textCol.GetAsString(i); err == nil { text = s } } var score float32 if i < len(r.Scores) { score = r.Scores[i] } hits = append(hits, Hit{Text: text, Score: score}) } } return hits, nil }