init: initial commit
This commit is contained in:
@@ -0,0 +1,145 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AICallConfig holds the resolved configuration for a single AI API call.
|
||||
type AICallConfig struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
MaxTokens int
|
||||
SystemPrompt string
|
||||
}
|
||||
|
||||
// ── Request / Response types (OpenAI-compatible format) ──────────────────────
|
||||
|
||||
type dsMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type dsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []dsMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type dsDelta struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
type dsChoice struct {
|
||||
Delta dsDelta `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
type dsSSELine struct {
|
||||
Choices []dsChoice `json:"choices"`
|
||||
}
|
||||
|
||||
// ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
// CallDeepSeekStream sends a messages list to any OpenAI-compatible endpoint
|
||||
// defined in cfg, with stream:true. Pushes each delta content chunk to streamCh.
|
||||
func CallDeepSeekStream(ctx context.Context, cfg AICallConfig, messages []dsMessage, streamCh chan<- string) error {
|
||||
if cfg.APIKey == "" {
|
||||
return fmt.Errorf("API key 未配置,请在设置中填写或联系管理员")
|
||||
}
|
||||
|
||||
payload := dsRequest{
|
||||
Model: cfg.Model,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.BaseURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("http request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("upstream status %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
return parseDeepSeekSSE(ctx, resp.Body, streamCh)
|
||||
}
|
||||
|
||||
// BuildRAGMessages constructs the OpenAI-compatible messages slice.
|
||||
// If customSystemPrompt is non-empty, it replaces the built-in RAG template.
|
||||
func BuildRAGMessages(knowledgeContext, userQuery, customSystemPrompt string) []dsMessage {
|
||||
var systemContent string
|
||||
if customSystemPrompt != "" {
|
||||
systemContent = customSystemPrompt
|
||||
if knowledgeContext != "" && knowledgeContext != "(无相关本地知识)" {
|
||||
systemContent += "\n\n以下是本地知识库中的相关内容供参考:\n---\n" + knowledgeContext + "\n---"
|
||||
}
|
||||
} else {
|
||||
systemContent = fmt.Sprintf(
|
||||
"你是一位专业的植物养护和客服顾问,擅长用温暖、自然、有亲和力的语气沟通。\n\n"+
|
||||
"以下是来自本地知识库的相关内容,请优先参考:\n\n---\n%s\n---\n\n"+
|
||||
"根据以上知识润色话术,直接输出内容,不加前缀或解释。",
|
||||
knowledgeContext,
|
||||
)
|
||||
}
|
||||
return []dsMessage{
|
||||
{Role: "system", Content: systemContent},
|
||||
{Role: "user", Content: userQuery},
|
||||
}
|
||||
}
|
||||
|
||||
func parseDeepSeekSSE(ctx context.Context, body io.Reader, ch chan<- string) error {
|
||||
scanner := bufio.NewScanner(body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 64*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
var event dsSSELine
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
if len(event.Choices) > 0 {
|
||||
if chunk := event.Choices[0].Delta.Content; chunk != "" {
|
||||
ch <- chunk
|
||||
}
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"AI-Expert-Sidebar/internal/database"
|
||||
"AI-Expert-Sidebar/internal/models"
|
||||
)
|
||||
|
||||
// ImportResult summarises the outcome of a CSV import.
|
||||
type ImportResult struct {
|
||||
Imported int `json:"imported"`
|
||||
Skipped int `json:"skipped"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ImportCSV reads a CSV file and inserts records into the active knowledge library.
|
||||
//
|
||||
// Required columns (case-insensitive): keyword, question, answer
|
||||
// Optional column: category (defaults to "通用")
|
||||
//
|
||||
// The first row must be the header.
|
||||
func ImportCSV(filePath string) ImportResult {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return ImportResult{Error: fmt.Sprintf("无法打开文件: %v", err)}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
db := database.Get()
|
||||
if db == nil {
|
||||
return ImportResult{Error: "知识库未初始化"}
|
||||
}
|
||||
|
||||
r := csv.NewReader(f)
|
||||
r.TrimLeadingSpace = true
|
||||
r.LazyQuotes = true
|
||||
|
||||
// Read and normalise header
|
||||
header, err := r.Read()
|
||||
if err != nil {
|
||||
return ImportResult{Error: fmt.Sprintf("读取表头失败: %v", err)}
|
||||
}
|
||||
colIdx := make(map[string]int)
|
||||
for i, h := range header {
|
||||
colIdx[strings.ToLower(strings.TrimSpace(h))] = i
|
||||
}
|
||||
for _, required := range []string{"keyword", "question", "answer"} {
|
||||
if _, ok := colIdx[required]; !ok {
|
||||
return ImportResult{Error: fmt.Sprintf("CSV 缺少必需列: %q (需要: keyword, question, answer)", required)}
|
||||
}
|
||||
}
|
||||
catIdx, hasCat := colIdx["category"]
|
||||
|
||||
var imported, skipped int
|
||||
for {
|
||||
row, err := r.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
keyword := strings.TrimSpace(row[colIdx["keyword"]])
|
||||
question := strings.TrimSpace(row[colIdx["question"]])
|
||||
answer := strings.TrimSpace(row[colIdx["answer"]])
|
||||
if keyword == "" || question == "" || answer == "" {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
cat := "通用"
|
||||
if hasCat && catIdx < len(row) {
|
||||
if v := strings.TrimSpace(row[catIdx]); v != "" {
|
||||
cat = v
|
||||
}
|
||||
}
|
||||
entry := models.Entry{Keyword: keyword, Question: question, Answer: answer, Category: cat}
|
||||
if err := db.Create(&entry).Error; err != nil {
|
||||
skipped++
|
||||
} else {
|
||||
imported++
|
||||
}
|
||||
}
|
||||
return ImportResult{Imported: imported, Skipped: skipped}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"AI-Expert-Sidebar/internal/database"
|
||||
"AI-Expert-Sidebar/internal/models"
|
||||
)
|
||||
|
||||
// ListLibraries returns all registered knowledge libraries with entry counts.
|
||||
func ListLibraries() ([]models.Library, error) {
|
||||
sdb := database.GetSettings()
|
||||
if sdb == nil {
|
||||
return nil, fmt.Errorf("settings DB not ready")
|
||||
}
|
||||
var libs []models.Library
|
||||
if err := sdb.Order("created_at asc").Find(&libs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Populate entry count for each library
|
||||
for i, lib := range libs {
|
||||
libs[i].EntryCount = countEntries(lib.FilePath)
|
||||
}
|
||||
return libs, nil
|
||||
}
|
||||
|
||||
// CreateLibrary registers a new knowledge library and creates its SQLite file.
|
||||
func CreateLibrary(name, description string) (*models.Library, error) {
|
||||
sdb := database.GetSettings()
|
||||
dir := database.DataDir
|
||||
|
||||
fileName := sanitizeFileName(name) + ".db"
|
||||
filePath := filepath.Join(dir, fileName)
|
||||
|
||||
// Ensure uniqueness of file path
|
||||
if _, err := os.Stat(filePath); err == nil {
|
||||
filePath = filepath.Join(dir, sanitizeFileName(name)+"_"+fmt.Sprintf("%d", time.Now().Unix())+".db")
|
||||
}
|
||||
|
||||
if err := database.NewLibraryDB(filePath); err != nil {
|
||||
return nil, fmt.Errorf("create library DB: %w", err)
|
||||
}
|
||||
|
||||
lib := models.Library{Name: name, Description: description, FilePath: filePath}
|
||||
if err := sdb.Create(&lib).Error; err != nil {
|
||||
os.Remove(filePath) // rollback file
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("[Library] Created: %s → %s", name, filePath)
|
||||
return &lib, nil
|
||||
}
|
||||
|
||||
// SwitchLibrary makes the named library active.
|
||||
func SwitchLibrary(name string) error {
|
||||
sdb := database.GetSettings()
|
||||
var lib models.Library
|
||||
if err := sdb.Where("name = ?", name).First(&lib).Error; err != nil {
|
||||
return fmt.Errorf("library %q not found", name)
|
||||
}
|
||||
return database.OpenLibrary(lib)
|
||||
}
|
||||
|
||||
// DeleteLibrary removes a library from the registry (and optionally its file).
|
||||
func DeleteLibrary(name string, deleteFile bool) error {
|
||||
sdb := database.GetSettings()
|
||||
var lib models.Library
|
||||
if err := sdb.Where("name = ?", name).First(&lib).Error; err != nil {
|
||||
return fmt.Errorf("library %q not found", name)
|
||||
}
|
||||
if err := sdb.Delete(&lib).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if deleteFile {
|
||||
return os.Remove(lib.FilePath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitLibraries restores the last active library or creates the default one.
|
||||
func InitLibraries() error {
|
||||
sdb := database.GetSettings()
|
||||
// Check active_library preference
|
||||
var setting models.AppSetting
|
||||
if sdb.Where("key = ?", "active_library").First(&setting).Error == nil {
|
||||
if err := SwitchLibrary(setting.Value); err == nil {
|
||||
return nil // restored successfully
|
||||
}
|
||||
}
|
||||
// No preference or stale — find first library
|
||||
var lib models.Library
|
||||
if sdb.Order("created_at asc").First(&lib).Error == nil {
|
||||
return database.OpenLibrary(lib)
|
||||
}
|
||||
// No libraries at all — create default
|
||||
lib2, err := CreateLibrary("默认知识库", "自动创建的默认知识库")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return database.OpenLibrary(*lib2)
|
||||
}
|
||||
|
||||
// ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func countEntries(filePath string) int {
|
||||
db, err := database.NewLibraryDBReadOnly(filePath)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
var count int64
|
||||
db.Model(&models.Entry{}).Count(&count)
|
||||
return int(count)
|
||||
}
|
||||
|
||||
func sanitizeFileName(name string) string {
|
||||
result := make([]rune, 0, len(name))
|
||||
for _, r := range name {
|
||||
if r == '/' || r == '\\' || r == ':' || r == '*' || r == '?' || r == '"' || r == '<' || r == '>' || r == '|' {
|
||||
result = append(result, '_')
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return "library"
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"AI-Expert-Sidebar/internal/database"
|
||||
"AI-Expert-Sidebar/internal/models"
|
||||
)
|
||||
|
||||
const maxResults = 5
|
||||
|
||||
var ErrDBUnavailable = errors.New("database unavailable")
|
||||
|
||||
// SearchResult is the DTO returned to the frontend.
|
||||
type SearchResult struct {
|
||||
ID uint `json:"id"`
|
||||
Question string `json:"question"`
|
||||
Answer string `json:"answer"`
|
||||
Category string `json:"category"`
|
||||
Score int `json:"score"` // 2=keyword, 1=question, 0=fallback
|
||||
IsFallback bool `json:"is_fallback"`
|
||||
}
|
||||
|
||||
// SearchKnowledge performs fuzzy search in the active knowledge library.
|
||||
func SearchKnowledge(query string) ([]SearchResult, error) {
|
||||
db := database.Get()
|
||||
if db == nil {
|
||||
return nil, ErrDBUnavailable
|
||||
}
|
||||
if len(query) < 1 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type res struct {
|
||||
rows []SearchResult
|
||||
err error
|
||||
}
|
||||
ch := make(chan res, 1)
|
||||
|
||||
go func() {
|
||||
var total int64
|
||||
db.Model(&models.Entry{}).Count(&total)
|
||||
if total == 0 {
|
||||
ch <- res{[]SearchResult{}, nil}
|
||||
return
|
||||
}
|
||||
|
||||
like := "%" + query + "%"
|
||||
var rows []models.Entry
|
||||
err := db.Where("keyword LIKE ? OR question LIKE ?", like, like).
|
||||
Order("updated_at DESC").Limit(maxResults).Find(&rows).Error
|
||||
if err != nil {
|
||||
ch <- res{nil, err}
|
||||
return
|
||||
}
|
||||
|
||||
isFallback := len(rows) == 0
|
||||
if isFallback {
|
||||
log.Printf("[Search] No match for %q, returning fallback", query)
|
||||
db.Order("updated_at DESC").Limit(3).Find(&rows)
|
||||
}
|
||||
|
||||
out := make([]SearchResult, 0, len(rows))
|
||||
for _, r := range rows {
|
||||
score := 0
|
||||
if !isFallback {
|
||||
if containsIgnoreCase(r.Keyword, query) {
|
||||
score = 2
|
||||
} else if containsIgnoreCase(r.Question, query) {
|
||||
score = 1
|
||||
}
|
||||
}
|
||||
out = append(out, SearchResult{
|
||||
ID: r.ID, Question: r.Question, Answer: r.Answer,
|
||||
Category: r.Category, Score: score, IsFallback: isFallback,
|
||||
})
|
||||
}
|
||||
ch <- res{out, nil}
|
||||
}()
|
||||
|
||||
r := <-ch
|
||||
return r.rows, r.err
|
||||
}
|
||||
|
||||
func containsIgnoreCase(s, sub string) bool {
|
||||
if len(sub) == 0 {
|
||||
return true
|
||||
}
|
||||
sl, subl := []rune(s), []rune(sub)
|
||||
if len(sl) < len(subl) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i <= len(sl)-len(subl); i++ {
|
||||
match := true
|
||||
for j, c := range subl {
|
||||
if toLower(sl[i+j]) != toLower(c) {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func toLower(r rune) rune {
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
return r + 32
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"AI-Expert-Sidebar/internal/config"
|
||||
"AI-Expert-Sidebar/internal/crypto"
|
||||
"AI-Expert-Sidebar/internal/database"
|
||||
"AI-Expert-Sidebar/internal/models"
|
||||
)
|
||||
|
||||
// SettingsDTO is what the frontend reads and writes.
|
||||
type SettingsDTO struct {
|
||||
AIProvider string `json:"ai_provider"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Model string `json:"model"`
|
||||
SystemPrompt string `json:"system_prompt"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
UsePublicKey bool `json:"use_public_key"`
|
||||
}
|
||||
|
||||
// GetSettings reads AI config from settings.db key-value store.
|
||||
func GetSettings() *SettingsDTO {
|
||||
sdb := database.GetSettings()
|
||||
if sdb == nil {
|
||||
return defaultDTO()
|
||||
}
|
||||
var rows []models.AppSetting
|
||||
sdb.Find(&rows)
|
||||
m := make(map[string]string, len(rows))
|
||||
for _, r := range rows {
|
||||
m[r.Key] = r.Value
|
||||
}
|
||||
|
||||
apiKey, _ := crypto.DecryptAPIKey(m["api_key_encrypted"])
|
||||
maxTokens := 1024
|
||||
fmt.Sscanf(m["max_tokens"], "%d", &maxTokens)
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = 1024
|
||||
}
|
||||
return &SettingsDTO{
|
||||
AIProvider: strOr(m["ai_provider"], "deepseek"),
|
||||
BaseURL: m["base_url"],
|
||||
APIKey: apiKey,
|
||||
Model: strOr(m["model"], "deepseek-chat"),
|
||||
SystemPrompt: m["system_prompt"],
|
||||
MaxTokens: maxTokens,
|
||||
UsePublicKey: m["use_public_key"] != "false",
|
||||
}
|
||||
}
|
||||
|
||||
// SaveSettings persists AI config into settings.db.
|
||||
func SaveSettings(dto SettingsDTO) error {
|
||||
sdb := database.GetSettings()
|
||||
if sdb == nil {
|
||||
return fmt.Errorf("settings DB not ready")
|
||||
}
|
||||
upsert := func(k, v string) {
|
||||
sdb.Exec("INSERT INTO app_settings(key,value) VALUES(?,?) ON CONFLICT(key) DO UPDATE SET value=excluded.value", k, v)
|
||||
}
|
||||
upsert("ai_provider", dto.AIProvider)
|
||||
upsert("base_url", dto.BaseURL)
|
||||
upsert("model", dto.Model)
|
||||
upsert("system_prompt", dto.SystemPrompt)
|
||||
upsert("max_tokens", fmt.Sprintf("%d", dto.MaxTokens))
|
||||
usePublic := "true"
|
||||
if !dto.UsePublicKey {
|
||||
usePublic = "false"
|
||||
}
|
||||
upsert("use_public_key", usePublic)
|
||||
if !dto.UsePublicKey && dto.APIKey != "" {
|
||||
enc, err := crypto.EncryptAPIKey(dto.APIKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
upsert("api_key_encrypted", enc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveAIConfig returns the effective AI call config (local settings or global fallback).
|
||||
func ResolveAIConfig() AICallConfig {
|
||||
base := AICallConfig{
|
||||
BaseURL: "https://api.deepseek.com/chat/completions",
|
||||
APIKey: config.Global.DeepSeek.APIKey,
|
||||
Model: config.Global.DeepSeek.Model,
|
||||
MaxTokens: config.Global.DeepSeek.MaxTokens,
|
||||
}
|
||||
dto := GetSettings()
|
||||
if dto == nil {
|
||||
return base
|
||||
}
|
||||
base.SystemPrompt = dto.SystemPrompt
|
||||
if dto.UsePublicKey || dto.APIKey == "" {
|
||||
return base
|
||||
}
|
||||
providerURL := dto.BaseURL
|
||||
if providerURL == "" {
|
||||
switch dto.AIProvider {
|
||||
case "deepseek":
|
||||
providerURL = "https://api.deepseek.com/chat/completions"
|
||||
case "openai":
|
||||
providerURL = "https://api.openai.com/v1/chat/completions"
|
||||
case "grok":
|
||||
providerURL = "https://api.x.ai/v1/chat/completions"
|
||||
}
|
||||
}
|
||||
maxTok := dto.MaxTokens
|
||||
if maxTok <= 0 {
|
||||
maxTok = 1024
|
||||
}
|
||||
return AICallConfig{
|
||||
BaseURL: strOr(providerURL, base.BaseURL),
|
||||
APIKey: dto.APIKey,
|
||||
Model: strOr(dto.Model, base.Model),
|
||||
MaxTokens: maxTok,
|
||||
SystemPrompt: dto.SystemPrompt,
|
||||
}
|
||||
}
|
||||
|
||||
func strOr(v, def string) string {
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func defaultDTO() *SettingsDTO {
|
||||
return &SettingsDTO{AIProvider: "deepseek", Model: "deepseek-chat", MaxTokens: 1024, UsePublicKey: true}
|
||||
}
|
||||
Reference in New Issue
Block a user