Files
AI-Expert-Sidebar/internal/service/ai_bridge.go
T
2026-04-01 15:29:35 +08:00

204 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package service 实现 OpenAI-compatible 流式 AI 调用(Server-Sent Events)。
//
// # 流式输出设计
//
// 本模块使用"流式请求 + Channel 推送"模式,而非"等待完整响应",原因:
// 1. 用户体验:DeepSeek/GPT 生成一条回复通常需要 3-30 秒,
// 流式输出让用户看到"逐字打印"效果,而非盯着空白等待;
// 2. 内存效率:完整回复可能超过 4096 token,流式逐块处理不会
// 在内存中积累超大字符串;
// 3. 可中断性:配合 context.WithCancel,用户随时可点击"停止"按钮
// 立即中断生成,而"一次性请求"无法在途中取消。
//
// # OpenAI 兼容协议
//
// DeepSeek、通义千问、ERNIE 等国内模型均提供"OpenAI 兼容接口"
// 支持完全相同的请求格式(/v1/chat/completions)和 SSE 响应格式。
// 本模块通过 AICallConfig.BaseURL 支持任意兼容端点,
// 用户只需在设置页填入对应的 API 地址即可切换模型,无需改代码。
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// AICallConfig 封装单次 AI 调用所需的全部配置,由 ResolveAIConfig() 在调用前动态解析。
//
// 之所以用 struct 而非全局变量,是为了让每次调用都能独立配置,
// 方便未来扩展"对话级 prompt 覆盖"或"A/B 测试不同模型"而不互相干扰。
type AICallConfig struct {
// BaseURL 是 API 端点,如 "https://api.deepseek.com/chat/completions"。
// 支持自定义,兼容所有 OpenAI-compatible 路由。
BaseURL string
// APIKey 对应 HTTP Header: Authorization: Bearer <APIKey>。
APIKey string
// Model 是模型标识符,如 "deepseek-chat" 或 "gpt-4o"。
Model string
// MaxTokens 限制单次生成的最大 token 数,防止意外超长输出(也控制费用)。
MaxTokens int
// SystemPrompt 是用户在设置页自定义的系统提示词,
// 覆盖 BuildRAGMessages 中的内置模板。
SystemPrompt string
}
// ── OpenAI-compatible 请求/响应结构体 ─────────────────────────────────────────
// dsMessage 对应 OpenAI messages 数组中的单条消息。
// Role: "system" | "user" | "assistant"
type dsMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// dsRequest 是发送给 API 的完整请求体。
// Stream:true 触发服务端以 text/event-stream 格式逐块返回,而非一次性 JSON。
type dsRequest struct {
Model string `json:"model"`
Messages []dsMessage `json:"messages"`
Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens,omitempty"` // omitempty:0 时省略字段,用服务端默认值
}
// dsDelta / dsChoice / dsSSELine 是 SSE 流中每个 data: {...} 行的反序列化目标。
// 每行格式:{"choices":[{"delta":{"content":"你好"},"finish_reason":null}]}
// finish_reason 非 null 时表示生成完成,但我们用 [DONE] 标记作为终止信号更可靠。
type dsDelta struct {
Content string `json:"content"`
}
type dsChoice struct {
Delta dsDelta `json:"delta"`
FinishReason string `json:"finish_reason"`
}
type dsSSELine struct {
Choices []dsChoice `json:"choices"`
}
// CallDeepSeekStream 向任意 OpenAI-compatible 端点发起流式 Chat 请求。
//
// 参数说明:
// - ctx: 携带取消信号,用户点击"停止"时 context 被取消,HTTP 请求立即中止;
// - cfg: 本次调用的 AI 配置(BaseURL/APIKey/Model),由调用方从数据库解析;
// - messages: OpenAI 格式的对话历史,包含 system 和 user 两条消息;
// - streamCh: 写端 channel,每解析到一个字符片段就推送进去;
// 接收端(handler/expert.go)通过 runtime.EventsEmit 转发给前端。
//
// 返回 nil 表示流正常结束(收到 [DONE]),返回 error 表示网络或协议错误。
func CallDeepSeekStream(ctx context.Context, cfg AICallConfig, messages []dsMessage, streamCh chan<- string) error {
if cfg.APIKey == "" {
return fmt.Errorf("API key 未配置,请在设置中填写或联系管理员")
}
payload := dsRequest{Model: cfg.Model, Messages: messages, Stream: true, MaxTokens: cfg.MaxTokens}
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal request: %w", err)
}
// Timeout 60s:流式请求通常在 30s 内完成,60s 留出余量;
// 不使用无限超时,防止网络异常时 goroutine 永久挂起
client := &http.Client{Timeout: 60 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.BaseURL, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream") // 告知服务端客户端期望 SSE 格式
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
resp, err := client.Do(req)
if err != nil {
if ctx.Err() != nil {
return ctx.Err() // 优先返回 context 取消错误,便于上层区分"用户停止"
}
return fmt.Errorf("http request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
errBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("上游返回 %d: %s", resp.StatusCode, string(errBody))
}
return parseDeepSeekSSE(ctx, resp.Body, streamCh)
}
// BuildRAGMessages 构造 OpenAI-compatible messages 数组,实现 RAG 增强。
//
// # RAG(检索增强生成)原理
//
// 传统:直接把用户问题发给 AI → AI 靠训练知识回答(可能不准确)。
// RAG:先在本地知识库检索相关内容 → 将检索结果放入 system prompt →
//
// AI 优先参考本地知识回答 → 答案更贴合业务场景。
//
// knowledgeContext 是检索到的本地问答片段(由 buildKnowledgeContext 生成)。
// customSystemPrompt 非空时替换内置的客服模板,让用户完全自定义 AI 人设。
func BuildRAGMessages(knowledgeContext, userQuery, customSystemPrompt string) []dsMessage {
var systemContent string
if customSystemPrompt != "" {
// 用户自定义 prompt 优先,但仍然追加本地知识作为参考
systemContent = customSystemPrompt
if knowledgeContext != "" && knowledgeContext != "(无相关本地知识)" {
systemContent += "\n\n以下是本地知识库中的相关内容供参考:\n---\n" + knowledgeContext + "\n---"
}
} else {
// 内置模板:默认为"客服顾问"角色,优先参考本地知识库内容
systemContent = fmt.Sprintf(
"你是一位专业的客服顾问,擅长用温暖、自然、有亲和力的语气沟通。\n\n"+
"以下是来自本地知识库的相关内容,请优先参考:\n\n---\n%s\n---\n\n"+
"根据以上知识润色话术,直接输出内容,不加前缀或解释。",
knowledgeContext,
)
}
return []dsMessage{
{Role: "system", Content: systemContent},
{Role: "user", Content: userQuery},
}
}
// parseDeepSeekSSE 逐行解析 SSEServer-Sent Events)流,
// 提取每个 delta.content 片段并推入 channel。
//
// SSE 协议规则(OpenAI 子集):
// - 以 "data: " 开头的行包含 JSON payload
// - "data: [DONE]" 是终止标记,收到后停止扫描;
// - 其他行(空行、": comment" 等)直接忽略。
//
// bufio.Scanner 缓冲区设置为 64 KB,防止超长单行(如图片 base64)超出默认的 64 KB 限制。
func parseDeepSeekSSE(ctx context.Context, body io.Reader, ch chan<- string) error {
scanner := bufio.NewScanner(body)
scanner.Buffer(make([]byte, 64*1024), 64*1024)
for scanner.Scan() {
if ctx.Err() != nil {
return ctx.Err() // 检查取消信号,及时退出扫描循环
}
line := scanner.Text()
if !strings.HasPrefix(line, "data:") {
continue
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "[DONE]" {
break
}
var event dsSSELine
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue // 解析失败的行静默跳过,不中断整个流
}
if len(event.Choices) > 0 {
if chunk := event.Choices[0].Delta.Content; chunk != "" {
ch <- chunk // 推入 channel,由 handler 层转发给前端
}
}
}
return scanner.Err()
}