AI-Writie-Assistant/internal/chat/llm_client.go

package chat

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
// LLMClient provides a unified interface for OpenAI-compatible LLM APIs.
type LLMClient struct {
	client *http.Client
}

// NewLLMClient creates a new LLM client. The long timeout accommodates
// slow model responses.
func NewLLMClient() *LLMClient {
	return &LLMClient{
		client: &http.Client{Timeout: 120 * time.Second},
	}
}
// Message represents a chat message in OpenAI format.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ChatRequest is the OpenAI chat completion request body.
type ChatRequest struct {
	Model       string    `json:"model"`
	Messages    []Message `json:"messages"`
	Stream      bool      `json:"stream"`
	Temperature float64   `json:"temperature,omitempty"`
	MaxTokens   int       `json:"max_tokens,omitempty"`
}
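
// For illustration, a minimal request serialized from ChatRequest looks like
// the JSON below; the model name and message content are placeholders, not
// values used elsewhere in this project:
//
//	{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}],"stream":false}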

// ChatResponse is a non-streaming response.
type ChatResponse struct {
	Choices []struct {
		Message Message `json:"message"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
	} `json:"usage"`
}
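
// For illustration, a minimal non-streaming response that decodes into
// ChatResponse looks like this (all values are placeholders):
//
//	{"choices":[{"message":{"role":"assistant","content":"Hi!"}}],
//	 "usage":{"prompt_tokens":5,"completion_tokens":2}}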

// streamChunk is a single decoded SSE payload from a streaming response.
type streamChunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
}
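
// A typical SSE data line that decodes into streamChunk looks like this
// (the content fragment is a placeholder):
//
//	data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}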

// Complete sends a non-streaming chat request and returns the parsed response.
func (c *LLMClient) Complete(baseURL, apiKey, model string, messages []Message) (*ChatResponse, error) {
	reqBody := ChatRequest{
		Model:    model,
		Messages: messages,
		Stream:   false,
	}
	body, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("encode llm request: %w", err)
	}
	// Trim a trailing slash so a baseURL like "https://host/" doesn't
	// produce a double-slash path.
	req, err := http.NewRequest("POST", strings.TrimSuffix(baseURL, "/")+"/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("build llm request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	resp, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("llm request: %w", err)
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read llm response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("llm api returned %d: %s", resp.StatusCode, strings.TrimSpace(string(data)))
	}
	var result ChatResponse
	if err := json.Unmarshal(data, &result); err != nil {
		return nil, fmt.Errorf("parse llm response: %w", err)
	}
	return &result, nil
}

// StreamComplete sends a streaming chat request, invoking onChunk with each
// content delta as it arrives over SSE.
func (c *LLMClient) StreamComplete(baseURL, apiKey, model string, messages []Message, onChunk func(string)) error {
	reqBody := ChatRequest{
		Model:    model,
		Messages: messages,
		Stream:   true,
	}
	body, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("encode llm request: %w", err)
	}
	req, err := http.NewRequest("POST", strings.TrimSuffix(baseURL, "/")+"/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("build llm request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	resp, err := c.client.Do(req)
	if err != nil {
		return fmt.Errorf("llm stream request: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		data, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("llm api returned %d: %s", resp.StatusCode, strings.TrimSpace(string(data)))
	}
	scanner := bufio.NewScanner(resp.Body)
	// Raise the line-size limit: a single SSE data line can exceed the
	// scanner's 64 KiB default maximum token size.
	scanner.Buffer(make([]byte, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip blank lines and SSE comments
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			break // server signals end of stream
		}
		var chunk streamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // tolerate malformed or keep-alive payloads
		}
		if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
			onChunk(chunk.Choices[0].Delta.Content)
		}
	}
	return scanner.Err()
}
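
// exampleUsage is a minimal sketch of how this client is intended to be
// called. The base URL, empty API key, and model name below are
// placeholders (an OpenAI-compatible local endpoint is assumed), not
// values defined elsewhere in this project.
func exampleUsage() error {
	c := NewLLMClient()
	msgs := []Message{
		{Role: "system", Content: "You are a writing assistant."},
		{Role: "user", Content: "Improve this sentence."},
	}
	// Non-streaming: block until the full completion arrives.
	resp, err := c.Complete("http://localhost:11434", "", "llama3", msgs)
	if err != nil {
		return err
	}
	if len(resp.Choices) > 0 {
		fmt.Println(resp.Choices[0].Message.Content)
	}
	// Streaming: print each content delta as it arrives.
	return c.StreamComplete("http://localhost:11434", "", "llama3", msgs, func(tok string) {
		fmt.Print(tok)
	})
}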