204 lines
8.2 KiB
Go
204 lines
8.2 KiB
Go
// Package service 实现 OpenAI-compatible 流式 AI 调用(Server-Sent Events)。
|
||
//
|
||
// # 流式输出设计
|
||
//
|
||
// 本模块使用"流式请求 + Channel 推送"模式,而非"等待完整响应",原因:
|
||
// 1. 用户体验:DeepSeek/GPT 生成一条回复通常需要 3-30 秒,
|
||
// 流式输出让用户看到"逐字打印"效果,而非盯着空白等待;
|
||
// 2. 内存效率:完整回复可能超过 4096 token,流式逐块处理不会
|
||
// 在内存中积累超大字符串;
|
||
// 3. 可中断性:配合 context.WithCancel,用户随时可点击"停止"按钮
|
||
// 立即中断生成,而"一次性请求"无法在途中取消。
|
||
//
|
||
// # OpenAI 兼容协议
|
||
//
|
||
// DeepSeek、通义千问、ERNIE 等国内模型均提供"OpenAI 兼容接口",
|
||
// 支持完全相同的请求格式(/v1/chat/completions)和 SSE 响应格式。
|
||
// 本模块通过 AICallConfig.BaseURL 支持任意兼容端点,
|
||
// 用户只需在设置页填入对应的 API 地址即可切换模型,无需改代码。
|
||
package service
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// AICallConfig 封装单次 AI 调用所需的全部配置,由 ResolveAIConfig() 在调用前动态解析。
|
||
//
|
||
// 之所以用 struct 而非全局变量,是为了让每次调用都能独立配置,
|
||
// 方便未来扩展"对话级 prompt 覆盖"或"A/B 测试不同模型"而不互相干扰。
|
||
type AICallConfig struct {
|
||
// BaseURL 是 API 端点,如 "https://api.deepseek.com/chat/completions"。
|
||
// 支持自定义,兼容所有 OpenAI-compatible 路由。
|
||
BaseURL string
|
||
// APIKey 对应 HTTP Header: Authorization: Bearer <APIKey>。
|
||
APIKey string
|
||
// Model 是模型标识符,如 "deepseek-chat" 或 "gpt-4o"。
|
||
Model string
|
||
// MaxTokens 限制单次生成的最大 token 数,防止意外超长输出(也控制费用)。
|
||
MaxTokens int
|
||
// SystemPrompt 是用户在设置页自定义的系统提示词,
|
||
// 覆盖 BuildRAGMessages 中的内置模板。
|
||
SystemPrompt string
|
||
}
|
||
|
||
// ── OpenAI-compatible 请求/响应结构体 ─────────────────────────────────────────
|
||
|
||
// dsMessage 对应 OpenAI messages 数组中的单条消息。
|
||
// Role: "system" | "user" | "assistant"
|
||
type dsMessage struct {
|
||
Role string `json:"role"`
|
||
Content string `json:"content"`
|
||
}
|
||
|
||
// dsRequest 是发送给 API 的完整请求体。
|
||
// Stream:true 触发服务端以 text/event-stream 格式逐块返回,而非一次性 JSON。
|
||
type dsRequest struct {
|
||
Model string `json:"model"`
|
||
Messages []dsMessage `json:"messages"`
|
||
Stream bool `json:"stream"`
|
||
MaxTokens int `json:"max_tokens,omitempty"` // omitempty:0 时省略字段,用服务端默认值
|
||
}
|
||
|
||
// dsDelta / dsChoice / dsSSELine 是 SSE 流中每个 data: {...} 行的反序列化目标。
|
||
// 每行格式:{"choices":[{"delta":{"content":"你好"},"finish_reason":null}]}
|
||
// finish_reason 非 null 时表示生成完成,但我们用 [DONE] 标记作为终止信号更可靠。
|
||
type dsDelta struct {
|
||
Content string `json:"content"`
|
||
}
|
||
type dsChoice struct {
|
||
Delta dsDelta `json:"delta"`
|
||
FinishReason string `json:"finish_reason"`
|
||
}
|
||
type dsSSELine struct {
|
||
Choices []dsChoice `json:"choices"`
|
||
}
|
||
|
||
// CallDeepSeekStream 向任意 OpenAI-compatible 端点发起流式 Chat 请求。
|
||
//
|
||
// 参数说明:
|
||
// - ctx: 携带取消信号,用户点击"停止"时 context 被取消,HTTP 请求立即中止;
|
||
// - cfg: 本次调用的 AI 配置(BaseURL/APIKey/Model),由调用方从数据库解析;
|
||
// - messages: OpenAI 格式的对话历史,包含 system 和 user 两条消息;
|
||
// - streamCh: 写端 channel,每解析到一个字符片段就推送进去;
|
||
// 接收端(handler/expert.go)通过 runtime.EventsEmit 转发给前端。
|
||
//
|
||
// 返回 nil 表示流正常结束(收到 [DONE]),返回 error 表示网络或协议错误。
|
||
func CallDeepSeekStream(ctx context.Context, cfg AICallConfig, messages []dsMessage, streamCh chan<- string) error {
|
||
if cfg.APIKey == "" {
|
||
return fmt.Errorf("API key 未配置,请在设置中填写或联系管理员")
|
||
}
|
||
|
||
payload := dsRequest{Model: cfg.Model, Messages: messages, Stream: true, MaxTokens: cfg.MaxTokens}
|
||
body, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return fmt.Errorf("marshal request: %w", err)
|
||
}
|
||
|
||
// Timeout 60s:流式请求通常在 30s 内完成,60s 留出余量;
|
||
// 不使用无限超时,防止网络异常时 goroutine 永久挂起
|
||
client := &http.Client{Timeout: 60 * time.Second}
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.BaseURL, bytes.NewReader(body))
|
||
if err != nil {
|
||
return fmt.Errorf("build request: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Accept", "text/event-stream") // 告知服务端客户端期望 SSE 格式
|
||
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
|
||
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
if ctx.Err() != nil {
|
||
return ctx.Err() // 优先返回 context 取消错误,便于上层区分"用户停止"
|
||
}
|
||
return fmt.Errorf("http request: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
errBody, _ := io.ReadAll(resp.Body)
|
||
return fmt.Errorf("上游返回 %d: %s", resp.StatusCode, string(errBody))
|
||
}
|
||
|
||
return parseDeepSeekSSE(ctx, resp.Body, streamCh)
|
||
}
|
||
|
||
// BuildRAGMessages 构造 OpenAI-compatible messages 数组,实现 RAG 增强。
|
||
//
|
||
// # RAG(检索增强生成)原理
|
||
//
|
||
// 传统:直接把用户问题发给 AI → AI 靠训练知识回答(可能不准确)。
|
||
// RAG:先在本地知识库检索相关内容 → 将检索结果放入 system prompt →
|
||
//
|
||
// AI 优先参考本地知识回答 → 答案更贴合业务场景。
|
||
//
|
||
// knowledgeContext 是检索到的本地问答片段(由 buildKnowledgeContext 生成)。
|
||
// customSystemPrompt 非空时替换内置的客服模板,让用户完全自定义 AI 人设。
|
||
func BuildRAGMessages(knowledgeContext, userQuery, customSystemPrompt string) []dsMessage {
|
||
var systemContent string
|
||
if customSystemPrompt != "" {
|
||
// 用户自定义 prompt 优先,但仍然追加本地知识作为参考
|
||
systemContent = customSystemPrompt
|
||
if knowledgeContext != "" && knowledgeContext != "(无相关本地知识)" {
|
||
systemContent += "\n\n以下是本地知识库中的相关内容供参考:\n---\n" + knowledgeContext + "\n---"
|
||
}
|
||
} else {
|
||
// 内置模板:默认为"客服顾问"角色,优先参考本地知识库内容
|
||
systemContent = fmt.Sprintf(
|
||
"你是一位专业的客服顾问,擅长用温暖、自然、有亲和力的语气沟通。\n\n"+
|
||
"以下是来自本地知识库的相关内容,请优先参考:\n\n---\n%s\n---\n\n"+
|
||
"根据以上知识润色话术,直接输出内容,不加前缀或解释。",
|
||
knowledgeContext,
|
||
)
|
||
}
|
||
return []dsMessage{
|
||
{Role: "system", Content: systemContent},
|
||
{Role: "user", Content: userQuery},
|
||
}
|
||
}
|
||
|
||
// parseDeepSeekSSE 逐行解析 SSE(Server-Sent Events)流,
|
||
// 提取每个 delta.content 片段并推入 channel。
|
||
//
|
||
// SSE 协议规则(OpenAI 子集):
|
||
// - 以 "data: " 开头的行包含 JSON payload;
|
||
// - "data: [DONE]" 是终止标记,收到后停止扫描;
|
||
// - 其他行(空行、": comment" 等)直接忽略。
|
||
//
|
||
// bufio.Scanner 缓冲区设置为 64 KB,防止超长单行(如图片 base64)超出默认的 64 KB 限制。
|
||
func parseDeepSeekSSE(ctx context.Context, body io.Reader, ch chan<- string) error {
|
||
scanner := bufio.NewScanner(body)
|
||
scanner.Buffer(make([]byte, 64*1024), 64*1024)
|
||
|
||
for scanner.Scan() {
|
||
if ctx.Err() != nil {
|
||
return ctx.Err() // 检查取消信号,及时退出扫描循环
|
||
}
|
||
line := scanner.Text()
|
||
if !strings.HasPrefix(line, "data:") {
|
||
continue
|
||
}
|
||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||
if data == "[DONE]" {
|
||
break
|
||
}
|
||
var event dsSSELine
|
||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||
continue // 解析失败的行静默跳过,不中断整个流
|
||
}
|
||
if len(event.Choices) > 0 {
|
||
if chunk := event.Choices[0].Delta.Content; chunk != "" {
|
||
ch <- chunk // 推入 channel,由 handler 层转发给前端
|
||
}
|
||
}
|
||
}
|
||
return scanner.Err()
|
||
}
|