189 lines
5.3 KiB
Go
189 lines
5.3 KiB
Go
package config
|
|
|
|
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"engimind/internal/models"
	"engimind/internal/vector"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
|
|
|
|
// ConfigService manages global LLM and VectorDB configuration.
|
|
// Bound to Wails as a service.
|
|
type ConfigService struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewConfigService initializes the global config database.
|
|
func NewConfigService() *ConfigService {
|
|
return &ConfigService{}
|
|
}
|
|
|
|
// OnStartup is called by Wails on app start.
|
|
func (s *ConfigService) Init() error {
|
|
homeDir, err := os.UserHomeDir()
|
|
if err != nil {
|
|
return fmt.Errorf("get home dir: %w", err)
|
|
}
|
|
dbDir := filepath.Join(homeDir, ".engimind")
|
|
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
|
return fmt.Errorf("create config dir: %w", err)
|
|
}
|
|
|
|
db, err := gorm.Open(sqlite.Open(filepath.Join(dbDir, "global.db")), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("open global db: %w", err)
|
|
}
|
|
s.db = db
|
|
|
|
if err := db.AutoMigrate(
|
|
&models.LLMProvider{},
|
|
&models.VectorDBConfig{},
|
|
&models.Project{},
|
|
); err != nil {
|
|
return fmt.Errorf("auto migrate: %w", err)
|
|
}
|
|
|
|
// Seed defaults if empty
|
|
var count int64
|
|
db.Model(&models.LLMProvider{}).Count(&count)
|
|
if count == 0 {
|
|
s.seedDefaults()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *ConfigService) seedDefaults() {
|
|
providers := []models.LLMProvider{
|
|
{ID: "cfg1", Name: "DeepSeek Cloud", Provider: "DeepSeek", BaseURL: "https://api.deepseek.com", APIKey: "", ModelID: "deepseek-reasoner", Enabled: true},
|
|
{ID: "cfg2", Name: "Ollama: Local", Provider: "Ollama", BaseURL: "http://localhost:11434", APIKey: "", ModelID: "qwen2.5:32b", Enabled: true},
|
|
}
|
|
for _, p := range providers {
|
|
s.db.Create(&p)
|
|
}
|
|
s.db.Create(&models.VectorDBConfig{
|
|
ID: 1, Endpoint: "http://localhost:6333", APIKey: "", Status: "disconnected",
|
|
})
|
|
}
|
|
|
|
// --- LLM Provider CRUD ---
|
|
|
|
func (s *ConfigService) GetAllProviders() ([]models.LLMProvider, error) {
|
|
var providers []models.LLMProvider
|
|
err := s.db.Find(&providers).Error
|
|
return providers, err
|
|
}
|
|
|
|
func (s *ConfigService) SaveProvider(p models.LLMProvider) error {
|
|
return s.db.Save(&p).Error
|
|
}
|
|
|
|
func (s *ConfigService) DeleteProvider(id string) error {
|
|
return s.db.Delete(&models.LLMProvider{}, "id = ?", id).Error
|
|
}
|
|
|
|
// --- VectorDB Config ---
|
|
|
|
func (s *ConfigService) GetVectorDBConfig() (models.VectorDBConfig, error) {
|
|
var cfg models.VectorDBConfig
|
|
err := s.db.First(&cfg).Error
|
|
return cfg, err
|
|
}
|
|
|
|
func (s *ConfigService) SaveVectorDBConfig(c models.VectorDBConfig) error {
|
|
c.ID = 1 // singleton
|
|
return s.db.Save(&c).Error
|
|
}
|
|
|
|
func (s *ConfigService) GetDB() *gorm.DB {
|
|
return s.db
|
|
}
|
|
|
|
// --- Connection Testing ---
|
|
|
|
// TestVectorDBConnection verifies the Qdrant server is reachable.
|
|
func (s *ConfigService) TestVectorDBConnection(endpoint string) (bool, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
// Normalize endpoint for gRPC (Qdrant default gRPC is 6334, HTTP is 6333)
|
|
// Strip http:// or https:// prefix
|
|
if len(endpoint) > 7 && endpoint[:7] == "http://" {
|
|
endpoint = endpoint[7:]
|
|
} else if len(endpoint) > 8 && endpoint[:8] == "https://" {
|
|
endpoint = endpoint[8:]
|
|
}
|
|
|
|
// If user inputs 6333 (HTTP port), we auto-correct to 6334 for gRPC for ease of use
|
|
if len(endpoint) > 5 && endpoint[len(endpoint)-5:] == ":6333" {
|
|
endpoint = endpoint[:len(endpoint)-5] + ":6334"
|
|
}
|
|
|
|
store, err := vector.NewQdrantStore(endpoint, 1) // dimension is mostly ignored for list operation
|
|
if err != nil {
|
|
return false, fmt.Errorf("建立到 Qdrant 的连接失败 (尝试连接 %s): %w", endpoint, err)
|
|
}
|
|
defer store.Close()
|
|
|
|
ok, err := store.TestConnection(ctx)
|
|
if err != nil {
|
|
return false, fmt.Errorf("Qdrant 测试请求错误 (请确保使用的是 gRPC 端口, 默认 6334): %w", err)
|
|
}
|
|
return ok, nil
|
|
}
|
|
|
|
// TestLLMConnection tests connectivity and auth configuration for a provider.
|
|
func (s *ConfigService) TestLLMConnection(provider, baseURL, apiKey string) (bool, error) {
|
|
var reqURL string
|
|
if provider == "Ollama" {
|
|
reqURL = baseURL + "/api/tags"
|
|
} else {
|
|
// DeepSeek/OpenAI compatible models endpoint
|
|
// Often requires /v1/models. If user provides baseURL with /v1, we just append /models
|
|
// Or we can just let it try the URL directly if we assume user's base URL is the completions one?
|
|
// Usually standard is "https://api.deepseek.com" or "https://api.deepseek.com/v1"
|
|
// Wait, some BaseURLs don't have /v1. Let's do a basic normalization.
|
|
if baseURL[len(baseURL)-1] == '/' {
|
|
baseURL = baseURL[:len(baseURL)-1]
|
|
}
|
|
if provider == "DeepSeek" || provider == "OpenAI" {
|
|
reqURL = baseURL + "/models"
|
|
} else {
|
|
reqURL = baseURL + "/v1/models"
|
|
}
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(context.Background(), "GET", reqURL, nil)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if apiKey != "" && provider != "Ollama" {
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
}
|
|
|
|
client := &http.Client{Timeout: 8 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return false, fmt.Errorf("网络连通性异常: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode >= 200 && resp.StatusCode < 400 {
|
|
return true, nil
|
|
}
|
|
if resp.StatusCode == 401 {
|
|
return false, fmt.Errorf("API Key 无效 (HTTP 无授权 %d)", resp.StatusCode)
|
|
}
|
|
return false, fmt.Errorf("返回非健康状态码: HTTP %d", resp.StatusCode)
|
|
}
|