package dsl

import (
	"encoding/json"
	"testing"
)

// graphJSON marshals f to JSON so it can be fed to Compile, failing the test
// immediately if marshalling errors.
func graphJSON(t *testing.T, f Flow) json.RawMessage {
	t.Helper()
	b, err := json.Marshal(f)
	if err != nil {
		t.Fatal(err)
	}
	return b
}

// TestTopo_Linear checks that a simple chain i -> a -> o is ordered as declared
// by its edges.
func TestTopo_Linear(t *testing.T) {
	f := Flow{
		Nodes: []Node{{ID: "i"}, {ID: "a"}, {ID: "o"}},
		Edges: []Edge{{Source: "i", Target: "a"}, {Source: "a", Target: "o"}},
	}
	order := f.Topo()

	// Collect IDs generically so a short result is reported as a failure
	// instead of panicking on an out-of-range index.
	ids := make([]string, 0, len(order))
	for _, n := range order {
		ids = append(ids, n.ID)
	}

	want := []string{"i", "a", "o"}
	if len(ids) != len(want) {
		t.Fatalf("Topo 顺序 = %v, want %v", ids, want)
	}
	for k := range want {
		if ids[k] != want[k] {
			t.Fatalf("Topo 顺序 = %v, want %v", ids, want)
		}
	}
}

// TestTopo_CycleFallback checks that a cyclic graph still yields every node.
func TestTopo_CycleFallback(t *testing.T) {
	// With a cycle, Kahn's algorithm cannot order the nodes; Topo degrades to
	// filling in the remainder in declaration order, but must not drop nodes.
	f := Flow{
		Nodes: []Node{{ID: "a"}, {ID: "b"}},
		Edges: []Edge{{Source: "a", Target: "b"}, {Source: "b", Target: "a"}},
	}
	got := f.Topo()
	if len(got) != 2 {
		t.Fatalf("有环也应返回全部节点, got %d", len(got))
	}

	// A correct length alone does not prove nothing was dropped (e.g. [a, a]),
	// so also require every declared node to be present.
	seen := make(map[string]bool, len(got))
	ids := make([]string, 0, len(got))
	for _, n := range got {
		seen[n.ID] = true
		ids = append(ids, n.ID)
	}
	for _, id := range []string{"a", "b"} {
		if !seen[id] {
			t.Fatalf("有环也不能丢节点 %q, got %v", id, ids)
		}
	}
}

// TestCompile checks that Compile takes System from the agent node, joins the
// input text and the agent prompt with a newline into Query, and collects the
// tool node's tool name into Tools.
func TestCompile(t *testing.T) {
	raw := graphJSON(t, Flow{Nodes: []Node{
		{ID: "i", Kind: "input", Config: map[string]any{"text": "你好"}},
		{ID: "a", Kind: "agent", Config: map[string]any{"system": "你是助手", "prompt": "再见"}},
		{ID: "t", Kind: "tool", Config: map[string]any{"tool": "wiki_search"}},
	}})
	p := Compile(raw)
	if p.System != "你是助手" {
		t.Errorf("System = %q", p.System)
	}
	if p.Query != "你好\n再见" {
		t.Errorf("Query = %q", p.Query)
	}
	if len(p.Tools) != 1 || p.Tools[0] != "wiki_search" {
		t.Errorf("Tools = %v", p.Tools)
	}
}

// TestCompile_NoInputFallback checks the defaults used when the graph has
// nodes but no input/agent node.
func TestCompile_NoInputFallback(t *testing.T) {
	// Nodes exist but none provide input/agent text → System falls back to
	// defaultSystem and Query falls back to "你好".
	raw := graphJSON(t, Flow{Nodes: []Node{
		{ID: "t", Kind: "tool", Config: map[string]any{"tool": "wiki_search"}},
	}})
	p := Compile(raw)
	if p.System != defaultSystem {
		t.Errorf("System 应兜底, got %q", p.System)
	}
	if p.Query != "你好" {
		t.Errorf("无输入 Query 应兜底为 你好, got %q", p.Query)
	}
}

// TestCompile_EmptyGraphUsesRaw checks the legacy path for input that is not a
// structured graph.
func TestCompile_EmptyGraphUsesRaw(t *testing.T) {
	// No nodes (the raw bytes cannot be parsed as a structured graph) → keep
	// the legacy behavior and use the raw text itself as the query.
	p := Compile(json.RawMessage(`帮我写个东西`))
	if p.Query != "帮我写个东西" {
		t.Errorf("空图应把原文当 Query, got %q", p.Query)
	}
}

// TestToolBinding checks that a tool node returns its configured tool with
// args parsed from the JSON string in Config["args"], that a retriever node
// maps to wiki_search with its kb passed through, and that other node kinds
// yield an empty tool name.
func TestToolBinding(t *testing.T) {
	tool, args := ToolBinding(Node{Kind: "tool", Config: map[string]any{"tool": "x", "args": `{"q":"v"}`}})
	if tool != "x" || args["q"] != "v" {
		t.Errorf("tool 节点 = %q %v", tool, args)
	}

	rtool, rargs := ToolBinding(Node{Kind: "retriever", Config: map[string]any{"kb": "docs"}})
	if rtool != "wiki_search" || rargs["kb"] != "docs" {
		t.Errorf("retriever 节点 = %q %v", rtool, rargs)
	}

	if ntool, _ := ToolBinding(Node{Kind: "agent"}); ntool != "" {
		t.Errorf("非工具节点应返回空工具, got %q", ntool)
	}
}

// TestParse checks that Parse accepts well-formed DSL JSON and rejects
// malformed JSON.
func TestParse(t *testing.T) {
	if _, err := Parse(json.RawMessage(`{"nodes":[{"id":"a"}],"edges":[]}`)); err != nil {
		t.Errorf("合法 DSL 不应报错: %v", err)
	}
	if _, err := Parse(json.RawMessage(`{bad json`)); err == nil {
		t.Error("非法 JSON 应报错")
	}
}