package rag

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strings"

	"github.com/neo4j/neo4j-go-driver/v5/neo4j"
	"github.com/neo4j/neo4j-go-driver/v5/neo4j/auth"
)

// graphDatabase is the Neo4j database that every query in this file targets.
const graphDatabase = "neo4j"

// Triple is a single knowledge triple (subject - predicate - object).
type Triple struct {
	S string `json:"s"`
	P string `json:"p"`
	O string `json:"o"`
}

// graphStore is the graph leg of GraphRAG: entities and relations are kept in
// Neo4j. A graphStore whose driver is nil is a valid "degraded" store; every
// method on it is a no-op.
type graphStore struct {
	driver neo4j.DriverWithContext
}

// openGraph connects to Neo4j at uri with basic auth and returns a store.
//
// It never fails: when uri is empty, the driver cannot be created, or the
// server is unreachable, it logs the reason (except for the empty-uri case)
// and returns a degraded store for which ready reports false.
func openGraph(ctx context.Context, uri, user, pass string) *graphStore {
	if uri == "" {
		return &graphStore{}
	}

	tokens := auth.BasicTokenManager(func(context.Context) (neo4j.AuthToken, error) {
		return neo4j.BasicAuth(user, pass, ""), nil
	})
	drv, err := neo4j.NewDriverWithContext(uri, tokens)
	if err != nil {
		log.Printf("[rag] Neo4j 连接失败,图谱路降级: %v", err)
		return &graphStore{}
	}
	if err := drv.VerifyConnectivity(ctx); err != nil {
		log.Printf("[rag] Neo4j 不可用,图谱路降级: %v", err)
		// The driver is being abandoned; release whatever it holds. The close
		// error is irrelevant because we are already on the degraded path.
		_ = drv.Close(ctx)
		return &graphStore{}
	}

	// Uniqueness constraint on (kb, name). It is created in the same database
	// the data queries use. A failure is not fatal (MERGE still works without
	// the constraint), so it is only logged.
	_, err = neo4j.ExecuteQuery(ctx, drv,
		"CREATE CONSTRAINT entity_key IF NOT EXISTS FOR (e:Entity) REQUIRE (e.kb, e.name) IS UNIQUE",
		nil, neo4j.EagerResultTransformer,
		neo4j.ExecuteQueryWithDatabase(graphDatabase))
	if err != nil {
		log.Printf("[rag] Neo4j entity constraint not created: %v", err)
	}

	log.Printf("[rag] Neo4j connected %s", uri)
	return &graphStore{driver: drv}
}

// ready reports whether the store has a live driver. It is safe to call on a
// nil receiver.
func (g *graphStore) ready() bool {
	return g != nil && g.driver != nil
}

// close releases the driver, if there is one.
func (g *graphStore) close(ctx context.Context) {
	if !g.ready() {
		return
	}
	// Best effort: the signature has no way to report a close failure.
	_ = g.driver.Close(ctx)
}

// store MERGEs the triples into Neo4j (both entities plus the relation),
// scoped by kb. Triples with an empty subject, predicate or object are
// skipped.
//
// It returns the number of triples written. On the first failing query it
// stops and returns the count written so far together with the wrapped error.
// A degraded store writes nothing and returns (0, nil).
func (g *graphStore) store(ctx context.Context, kb string, triples []Triple) (int, error) {
	if !g.ready() {
		return 0, nil
	}

	const query = "MERGE (a:Entity {kb:$kb, name:$s}) " +
		"MERGE (b:Entity {kb:$kb, name:$o}) " +
		"MERGE (a)-[r:REL {type:$p}]->(b)"

	written := 0
	for _, t := range triples {
		if t.S == "" || t.P == "" || t.O == "" {
			continue
		}
		_, err := neo4j.ExecuteQuery(ctx, g.driver, query,
			map[string]any{"kb": kb, "s": t.S, "o": t.O, "p": t.P},
			neo4j.EagerResultTransformer,
			neo4j.ExecuteQueryWithDatabase(graphDatabase))
		if err != nil {
			return written, fmt.Errorf("storing triple (%q, %q, %q) in kb %q: %w", t.S, t.P, t.O, kb, err)
		}
		written++
	}
	return written, nil
}

// search is the graph recall path: it finds relations in kb whose subject or
// object name occurs as a substring of query and returns up to limit of them
// rendered as text hits with score 1.
//
// It returns nil when the store is degraded, query is empty, limit is not
// positive, the query fails (the failure is logged), or nothing matches.
func (g *graphStore) search(ctx context.Context, kb, query string, limit int) []Hit {
	if !g.ready() || query == "" || limit <= 0 {
		return nil
	}

	const cypher = "MATCH (a:Entity {kb:$kb})-[r:REL]->(b:Entity {kb:$kb}) " +
		"WHERE $q CONTAINS a.name OR $q CONTAINS b.name " +
		"RETURN a.name AS s, r.type AS p, b.name AS o LIMIT $k"

	res, err := neo4j.ExecuteQuery(ctx, g.driver, cypher,
		map[string]any{"kb": kb, "q": query, "k": limit},
		neo4j.EagerResultTransformer,
		neo4j.ExecuteQueryWithDatabase(graphDatabase))
	if err != nil {
		log.Printf("[rag] graph search failed (kb=%q): %v", kb, err)
		return nil
	}

	// Left nil when there are no records, so callers see "no hits" as nil.
	var hits []Hit
	for _, rec := range res.Records {
		t := recordTriple(rec)
		hits = append(hits, Hit{Text: fmt.Sprintf("%v —%v→ %v", t.S, t.P, t.O), Score: 1})
	}
	return hits
}

// triples returns up to limit triples stored for kb (used by the UI graph
// visualisation).
//
// It returns nil when the store is degraded, limit is not positive, the query
// fails (the failure is logged), or kb has no relations.
func (g *graphStore) triples(ctx context.Context, kb string, limit int) []Triple {
	if !g.ready() || limit <= 0 {
		return nil
	}

	const cypher = "MATCH (a:Entity {kb:$kb})-[r:REL]->(b:Entity {kb:$kb}) " +
		"RETURN a.name AS s, r.type AS p, b.name AS o LIMIT $k"

	res, err := neo4j.ExecuteQuery(ctx, g.driver, cypher,
		map[string]any{"kb": kb, "k": limit},
		neo4j.EagerResultTransformer,
		neo4j.ExecuteQueryWithDatabase(graphDatabase))
	if err != nil {
		log.Printf("[rag] graph triples query failed (kb=%q): %v", kb, err)
		return nil
	}

	// Left nil when there are no records, matching the no-result case above.
	var out []Triple
	for _, rec := range res.Records {
		out = append(out, recordTriple(rec))
	}
	return out
}

// recordTriple reads the "s", "p" and "o" columns of a result record and
// renders each with fmt.Sprint, so a missing or null column becomes "<nil>".
func recordTriple(rec *neo4j.Record) Triple {
	s, _ := rec.Get("s")
	p, _ := rec.Get("p")
	o, _ := rec.Get("o")
	return Triple{S: fmt.Sprint(s), P: fmt.Sprint(p), O: fmt.Sprint(o)}
}

// extractTriples asks the LLM to extract knowledge triples from text.
//
// It returns (nil, nil) when chat is not ready or text is blank, the chat
// error when the completion fails, and otherwise whatever parseTriples can
// recover from the reply (possibly nil).
func extractTriples(ctx context.Context, chat *chatClient, text string) ([]Triple, error) {
	if !chat.ready() {
		return nil, nil
	}
	// Nothing to extract from blank input; skip the model round trip.
	if strings.TrimSpace(text) == "" {
		return nil, nil
	}

	const sys = "你是知识图谱抽取器。从用户文本中抽取知识三元组,输出 JSON 数组,每项形如 {\"s\":\"主体\",\"p\":\"关系\",\"o\":\"客体\"}。实体用简洁名词,关系用简短动词短语。只输出 JSON,不要任何解释或代码块标记。"

	reply, err := chat.complete(ctx, sys, text)
	if err != nil {
		return nil, err
	}
	return parseTriples(reply), nil
}

// parseTriples parses a JSON array of triples out of s, tolerating code
// fences and arbitrary text before or after the array.
//
// Each '[' in s is tried in turn as the start of the array. A streaming
// decoder is used so that decoding stops at the array's closing bracket and
// trailing noise (even noise containing ']') is ignored. The first candidate
// that decodes cleanly wins; nil is returned when none does.
func parseTriples(s string) []Triple {
	for start := strings.IndexByte(s, '['); start >= 0; {
		// Fresh slice per attempt so a failed decode cannot leak partial data.
		var triples []Triple
		if json.NewDecoder(strings.NewReader(s[start:])).Decode(&triples) == nil {
			return triples
		}
		next := strings.IndexByte(s[start+1:], '[')
		if next < 0 {
			break
		}
		start += next + 1
	}
	return nil
}