package bus_test

import (
	"context"
	"encoding/json"
	"strings"
	"testing"
	"time"

	natsserver "github.com/nats-io/nats-server/v2/server"
	natstest "github.com/nats-io/nats-server/v2/test"
	"github.com/sundynix/sundynix-shared/bus"
	"github.com/sundynix/sundynix-shared/contract"
)

// startEmbeddedNATS starts an in-process NATS server with JetStream enabled
// (no Docker required) and returns its client URL.
//
// The server listens on a random port, keeps its JetStream data under
// t.TempDir(), and is shut down through t.Cleanup when the test ends.
func startEmbeddedNATS(t *testing.T) string {
	t.Helper()

	opts := natstest.DefaultTestOptions
	opts.Port = -1 // random port
	opts.JetStream = true
	opts.StoreDir = t.TempDir()

	// The explicit type keeps the natsserver import in use without needing a
	// throwaway value.
	var srv *natsserver.Server = natstest.RunServer(&opts)

	// Register the shutdown before the readiness check so the server is also
	// stopped when that check fails.
	t.Cleanup(srv.Shutdown)

	if !srv.ReadyForConnections(5 * time.Second) {
		t.Fatal("embedded NATS not ready")
	}
	return srv.ClientURL()
}

// TestTaskRoundTrip simulates the full task flow:
// Gateway publishes -> NATS -> Dispatcher consumes.
func TestTaskRoundTrip(t *testing.T) {
	url := startEmbeddedNATS(t)

	// --- Gateway side: connect and declare the task stream ---
	gw, err := bus.Connect(url)
	if err != nil {
		t.Fatalf("gateway connect: %v", err)
	}
	defer gw.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	if err := gw.EnsureTaskStream(ctx); err != nil {
		t.Fatalf("ensure stream: %v", err)
	}

	// --- Dispatcher side: connect and start consuming ---
	dp, err := bus.Connect(url)
	if err != nil {
		t.Fatalf("dispatcher connect: %v", err)
	}
	defer dp.Close()

	got := make(chan *contract.Task, 1)
	stop, err := dp.ConsumeTasks(ctx, func(_ context.Context, task *contract.Task) error {
		// Non-blocking send: the test only needs the first delivery. If the
		// same task were delivered again, a blocking send on the full 1-slot
		// channel would stall the handler.
		select {
		case got <- task:
		default:
		}
		return nil
	})
	if err != nil {
		t.Fatalf("consume: %v", err)
	}
	defer stop()

	// --- Gateway publishes one task ---
	want := &contract.Task{
		ID:    "task_demo_001",
		Graph: json.RawMessage(`{"nodes":[{"id":"n1","type":"agent"}],"edges":[]}`),
		Meta:  map[string]any{"user": "wt"},
	}
	seq, err := gw.PublishTask(ctx, want)
	if err != nil {
		t.Fatalf("publish: %v", err)
	}
	if seq == 0 {
		t.Fatal("expected non-zero stream sequence")
	}

	// --- Assert the Dispatcher received the same task ---
	select {
	case task := <-got:
		if task.ID != want.ID {
			t.Fatalf("task id = %q, want %q", task.ID, want.ID)
		}
		if task.Meta["user"] != "wt" {
			t.Fatalf("task meta lost: %+v", task.Meta)
		}
		t.Logf("✓ 任务流打通:Gateway publish (seq=%d) → NATS → Dispatcher consume,task_id=%s", seq, task.ID)
	case <-time.After(5 * time.Second):
		t.Fatal("timeout: dispatcher 未收到任务")
	}
}

// TestTokenStreamRoundTrip simulates the streaming loop:
// Dispatcher publishes tokens -> Gateway subscription receives them.
func TestTokenStreamRoundTrip(t *testing.T) {
	url := startEmbeddedNATS(t)

	// Gateway side: subscribe first. Per the original note, core NATS has no
	// persistence, so the subscriber must be connected before publishing.
	gw, err := bus.Connect(url)
	if err != nil {
		t.Fatalf("gateway connect: %v", err)
	}
	defer gw.Close()

	const taskID = "task_stream_001"

	// NOTE(review): got is appended to in the token callback and read after
	// done is closed. That is race-free only if SubscribeTokens runs both
	// callbacks on one goroutine, with the done callback after the last
	// token — confirm against the bus implementation, or guard got with a
	// mutex.
	var got []string
	done := make(chan struct{})
	unsub, err := gw.SubscribeTokens(taskID,
		func(tok []byte) { got = append(got, string(tok)) },
		func() { close(done) },
	)
	if err != nil {
		t.Fatalf("subscribe tokens: %v", err)
	}
	// Best-effort unsubscribe at teardown; the error is irrelevant to the
	// assertions below.
	defer func() { _ = unsub() }()

	// Dispatcher side: publish the tokens one by one, then signal completion.
	dp, err := bus.Connect(url)
	if err != nil {
		t.Fatalf("dispatcher connect: %v", err)
	}
	defer dp.Close()

	want := []string{"Hello", " ", "Agent", "!"}
	for _, tok := range want {
		if err := dp.PublishToken(taskID, []byte(tok)); err != nil {
			t.Fatalf("publish token: %v", err)
		}
	}
	if err := dp.CompleteStream(taskID); err != nil {
		t.Fatalf("complete stream: %v", err)
	}

	select {
	case <-done:
		joined := strings.Join(got, "")
		// Derived from want so the expectation cannot drift from the
		// published tokens; evaluates to "Hello Agent!".
		wantJoined := strings.Join(want, "")
		if joined != wantJoined {
			t.Fatalf("token stream = %q, want %q", joined, wantJoined)
		}
		t.Logf("✓ Token 流闭环:Dispatcher 回流 %d 个 token → Gateway 拼回 %q", len(got), joined)
	case <-time.After(5 * time.Second):
		t.Fatal("timeout: 未收到流结束信号")
	}
}