init: initial commit
This commit is contained in:
@@ -0,0 +1,239 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"engimind/internal/config"
|
||||
"engimind/internal/project"
|
||||
"engimind/internal/vector"
|
||||
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
)
|
||||
|
||||
// ChatService handles chat interactions and chapter generation (A+C→B).
// All LLM traffic goes through llm; retrieval goes through rag.
type ChatService struct {
	llm        *LLMClient              // OpenAI-compatible completion/stream client
	rag        *vector.RAGService      // vector search over imported project documents
	configSvc  *config.ConfigService   // global LLM provider settings (lookup by model ID)
	projectSvc *project.ProjectService // supplies the currently open project ID
}
|
||||
|
||||
// NewChatService creates a chat service.
|
||||
func NewChatService(
|
||||
configSvc *config.ConfigService,
|
||||
projectSvc *project.ProjectService,
|
||||
rag *vector.RAGService,
|
||||
) *ChatService {
|
||||
return &ChatService{
|
||||
llm: NewLLMClient(),
|
||||
rag: rag,
|
||||
configSvc: configSvc,
|
||||
projectSvc: projectSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessage handles a user chat message with RAG context.
|
||||
func (s *ChatService) SendMessage(content string, selectedFileIDs []string, modelID string) (string, error) {
|
||||
providers, _ := s.configSvc.GetAllProviders()
|
||||
var provider *struct{ URL, Key, Model, ProviderType string }
|
||||
for _, p := range providers {
|
||||
if p.ID == modelID && p.Enabled {
|
||||
provider = &struct{ URL, Key, Model, ProviderType string }{p.BaseURL, p.APIKey, p.ModelID, p.Provider}
|
||||
break
|
||||
}
|
||||
}
|
||||
if provider == nil {
|
||||
return "", fmt.Errorf("model %s not found or disabled", modelID)
|
||||
}
|
||||
|
||||
// RAG: search context
|
||||
projectID := s.projectSvc.GetCurrentProjectID()
|
||||
var contextText string
|
||||
if projectID != "" && len(selectedFileIDs) > 0 {
|
||||
embCfg := s.getEmbeddingConfig()
|
||||
chunks, err := s.rag.SearchContext(context.Background(), projectID, content, 5, embCfg)
|
||||
if err != nil {
|
||||
slog.Warn("RAG search failed, proceeding without context", "err", err)
|
||||
} else {
|
||||
var parts []string
|
||||
for _, c := range chunks {
|
||||
parts = append(parts, c.Text)
|
||||
}
|
||||
contextText = strings.Join(parts, "\n\n---\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "你是一位专业的工程技术助手。基于提供的工程素材回答问题,引用来源时使用 [N] 标注。"},
|
||||
}
|
||||
if contextText != "" {
|
||||
messages = append(messages, Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("参考以下工程素材:\n\n%s\n\n---\n\n用户提问:%s", contextText, content),
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, Message{Role: "user", Content: content})
|
||||
}
|
||||
|
||||
resp, err := s.llm.Complete(provider.URL, provider.Key, provider.Model, messages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no response from model")
|
||||
}
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// StreamMessage handles a user chat message and streams the response via Wails events.
|
||||
func (s *ChatService) StreamMessage(content string, selectedFileIDs []string, modelID string, messageID string) (string, error) {
|
||||
providers, _ := s.configSvc.GetAllProviders()
|
||||
var provider *struct{ URL, Key, Model, ProviderType string }
|
||||
for _, p := range providers {
|
||||
if p.ID == modelID && p.Enabled {
|
||||
provider = &struct{ URL, Key, Model, ProviderType string }{p.BaseURL, p.APIKey, p.ModelID, p.Provider}
|
||||
break
|
||||
}
|
||||
}
|
||||
if provider == nil {
|
||||
return "", fmt.Errorf("model %s not found or disabled", modelID)
|
||||
}
|
||||
|
||||
projectID := s.projectSvc.GetCurrentProjectID()
|
||||
var contextText string
|
||||
if projectID != "" && len(selectedFileIDs) > 0 {
|
||||
embCfg := s.getEmbeddingConfig()
|
||||
chunks, err := s.rag.SearchContext(context.Background(), projectID, content, 5, embCfg)
|
||||
if err == nil {
|
||||
var parts []string
|
||||
for _, c := range chunks {
|
||||
parts = append(parts, c.Text)
|
||||
}
|
||||
contextText = strings.Join(parts, "\n\n---\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "你是一位专业的工程技术助手。基于提供的工程素材回答问题,引用来源时使用 [N] 标注。"},
|
||||
}
|
||||
if contextText != "" {
|
||||
messages = append(messages, Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("参考以下工程素材:\n\n%s\n\n---\n\n用户提问:%s", contextText, content),
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, Message{Role: "user", Content: content})
|
||||
}
|
||||
|
||||
var fullText string
|
||||
err := s.llm.StreamComplete(provider.URL, provider.Key, provider.Model, messages, func(chunk string) {
|
||||
fullText += chunk
|
||||
application.Get().Event.Emit("chat_stream_"+messageID, fullText)
|
||||
})
|
||||
return fullText, err
|
||||
}
|
||||
|
||||
// GenerateChapter implements the A+C→B logic for a single chapter.
|
||||
func (s *ChatService) GenerateChapter(chapterTitle string, selectedFileIDs []string, modelID string) (string, error) {
|
||||
providers, _ := s.configSvc.GetAllProviders()
|
||||
var provider *struct{ URL, Key, Model string }
|
||||
for _, p := range providers {
|
||||
if p.ID == modelID && p.Enabled {
|
||||
provider = &struct{ URL, Key, Model string }{p.BaseURL, p.APIKey, p.ModelID}
|
||||
break
|
||||
}
|
||||
}
|
||||
if provider == nil {
|
||||
return "", fmt.Errorf("model %s not found or disabled", modelID)
|
||||
}
|
||||
|
||||
// RAG: search context for chapter topic
|
||||
projectID := s.projectSvc.GetCurrentProjectID()
|
||||
embCfg := s.getEmbeddingConfig()
|
||||
chunks, err := s.rag.SearchContext(context.Background(), projectID, chapterTitle, 8, embCfg)
|
||||
if err != nil {
|
||||
slog.Warn("RAG search failed for chapter generation", "err", err)
|
||||
}
|
||||
|
||||
var contextParts []string
|
||||
for _, c := range chunks {
|
||||
contextParts = append(contextParts, c.Text)
|
||||
}
|
||||
contextText := strings.Join(contextParts, "\n\n---\n\n")
|
||||
|
||||
prompt := fmt.Sprintf(
|
||||
"你是工程报告撰写专家。请根据以下工程素材,按照章节要求撰写报告内容。\n\n"+
|
||||
"## 章节要求\n%s\n\n"+
|
||||
"## 参考素材\n%s\n\n"+
|
||||
"## 输出要求\n"+
|
||||
"1. 使用 Markdown 格式\n"+
|
||||
"2. 引用素材时使用 [N] 标注\n"+
|
||||
"3. 内容专业、结构清晰\n"+
|
||||
"4. 包含具体数据和分析结论",
|
||||
chapterTitle, contextText,
|
||||
)
|
||||
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "你是一位资深工程报告撰写专家,擅长根据工程素材生成结构化的技术报告章节。"},
|
||||
{Role: "user", Content: prompt},
|
||||
}
|
||||
|
||||
resp, err := s.llm.Complete(provider.URL, provider.Key, provider.Model, messages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no response from model")
|
||||
}
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// StreamTemplateDirectory uses the active LLM to extract a structured chapter outline from delivery standard text, streaming to frontend.
|
||||
func (s *ChatService) StreamTemplateDirectory(content string, modelID string, messageID string) (string, error) {
|
||||
providers, _ := s.configSvc.GetAllProviders()
|
||||
var provider *struct{ URL, Key, Model, ProviderType string }
|
||||
for _, p := range providers {
|
||||
if p.ID == modelID && p.Enabled {
|
||||
provider = &struct{ URL, Key, Model, ProviderType string }{p.BaseURL, p.APIKey, p.ModelID, p.Provider}
|
||||
break
|
||||
}
|
||||
}
|
||||
if provider == nil {
|
||||
return "", fmt.Errorf("model %s not found or disabled", modelID)
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(
|
||||
"你是一个工程标准的目录解析助手。请从下面提供的交付标准文本中提取工程的主干章节,并以 JSON 数组的格式返回。\n\n"+
|
||||
"### 要求:\n"+
|
||||
"1. 只返回 JSON 数组,不包含其他废话或者回答前缀。\n"+
|
||||
"2. 输出格式必须严格符合:[{ \"id\": \"chapter-1\", \"title\": \"1. 原材料进场检验\", \"content\": \"...如果标准里有简要描述可附上\" }]\n\n"+
|
||||
"### 交付标准内容:\n%s",
|
||||
content,
|
||||
)
|
||||
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "你是一个专业的结构化数据抽取工具。你只输出合法的 JSON,不要使用 Markdown 代码块包裹,也不要给出任何其他解释。"},
|
||||
{Role: "user", Content: prompt},
|
||||
}
|
||||
|
||||
var fullText string
|
||||
err := s.llm.StreamComplete(provider.URL, provider.Key, provider.Model, messages, func(chunk string) {
|
||||
fullText += chunk
|
||||
application.Get().Event.Emit("chat_stream_"+messageID, fullText)
|
||||
})
|
||||
|
||||
return fullText, err
|
||||
}
|
||||
|
||||
func (s *ChatService) getEmbeddingConfig() vector.EmbeddingConfig {
|
||||
// Use bge-m3 via Ollama as default embedding model
|
||||
return vector.EmbeddingConfig{
|
||||
BaseURL: "http://localhost:11434",
|
||||
Model: "bge-m3",
|
||||
APIKey: "",
|
||||
Provider: "Ollama",
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LLMClient provides a unified interface for OpenAI-compatible LLM APIs.
type LLMClient struct {
	client *http.Client // shared HTTP client; Timeout bounds the whole request
}

// NewLLMClient creates a new LLM client with a 120-second request timeout.
func NewLLMClient() *LLMClient {
	httpClient := &http.Client{Timeout: 2 * time.Minute}
	return &LLMClient{client: httpClient}
}
|
||||
|
||||
// Message represents a chat message in OpenAI format.
type Message struct {
	Role    string `json:"role"` // "system", "user" or "assistant"
	Content string `json:"content"`
}

// ChatRequest is the OpenAI chat completion request body.
type ChatRequest struct {
	Model       string    `json:"model"`
	Messages    []Message `json:"messages"`
	Stream      bool      `json:"stream"`                // true selects SSE streaming responses
	Temperature float64   `json:"temperature,omitempty"` // omitted when zero → server default
	MaxTokens   int       `json:"max_tokens,omitempty"`  // omitted when zero → server default
}

// ChatResponse is a non-streaming chat completion response.
type ChatResponse struct {
	Choices []struct {
		Message Message `json:"message"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
	} `json:"usage"`
}

// streamChunk is one SSE "data:" payload of a streaming completion; only the
// incremental delta content and the finish reason are decoded.
type streamChunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason"` // nil until the stream ends
	} `json:"choices"`
}
|
||||
|
||||
// Complete sends a non-streaming chat request.
|
||||
func (c *LLMClient) Complete(baseURL, apiKey, model string, messages []Message) (*ChatResponse, error) {
|
||||
reqBody := ChatRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
req, _ := http.NewRequest("POST", baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("llm request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
var result ChatResponse
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse llm response: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// StreamComplete sends a streaming chat request, calling onChunk for each token.
|
||||
func (c *LLMClient) StreamComplete(baseURL, apiKey, model string, messages []Message, onChunk func(string)) error {
|
||||
reqBody := ChatRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
req, _ := http.NewRequest("POST", baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("llm stream request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var chunk streamChunk
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
onChunk(chunk.Choices[0].Delta.Content)
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package config
|
||||
|
||||
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"engimind/internal/models"
	"engimind/internal/vector"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
|
||||
|
||||
// ConfigService manages global LLM and VectorDB configuration.
// Bound to Wails as a service. The backing SQLite database is opened lazily
// by Init; db is nil until then.
type ConfigService struct {
	db *gorm.DB // global config database (~/.engimind/global.db)
}

// NewConfigService returns an uninitialized service; call Init before use.
func NewConfigService() *ConfigService {
	return &ConfigService{}
}
|
||||
|
||||
// Init opens (creating if necessary) the global configuration database at
// ~/.engimind/global.db, migrates its schema, and seeds default provider
// rows on first run. It must be called (e.g. at app startup) before any
// other method touches s.db.
func (s *ConfigService) Init() error {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("get home dir: %w", err)
	}
	dbDir := filepath.Join(homeDir, ".engimind")
	if err := os.MkdirAll(dbDir, 0755); err != nil {
		return fmt.Errorf("create config dir: %w", err)
	}

	// Silent logger: gorm would otherwise log every query to stdout.
	db, err := gorm.Open(sqlite.Open(filepath.Join(dbDir, "global.db")), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		return fmt.Errorf("open global db: %w", err)
	}
	s.db = db

	if err := db.AutoMigrate(
		&models.LLMProvider{},
		&models.VectorDBConfig{},
		&models.Project{},
	); err != nil {
		return fmt.Errorf("auto migrate: %w", err)
	}

	// Seed defaults only when no provider rows exist (first launch).
	var count int64
	db.Model(&models.LLMProvider{}).Count(&count)
	if count == 0 {
		s.seedDefaults()
	}
	return nil
}
|
||||
|
||||
// seedDefaults inserts the built-in provider presets (DeepSeek cloud and a
// local Ollama instance) and a disconnected default Qdrant config. Insert
// errors are intentionally ignored: seeding is best-effort on first launch.
func (s *ConfigService) seedDefaults() {
	providers := []models.LLMProvider{
		{ID: "cfg1", Name: "DeepSeek Cloud", Provider: "DeepSeek", BaseURL: "https://api.deepseek.com", APIKey: "", ModelID: "deepseek-reasoner", Enabled: true},
		{ID: "cfg2", Name: "Ollama: Local", Provider: "Ollama", BaseURL: "http://localhost:11434", APIKey: "", ModelID: "qwen2.5:32b", Enabled: true},
	}
	for _, p := range providers {
		s.db.Create(&p)
	}
	// Singleton row (ID 1) — see SaveVectorDBConfig.
	s.db.Create(&models.VectorDBConfig{
		ID: 1, Endpoint: "http://localhost:6333", APIKey: "", Status: "disconnected",
	})
}
|
||||
|
||||
// --- LLM Provider CRUD ---

// GetAllProviders returns every configured LLM provider row (enabled or not).
func (s *ConfigService) GetAllProviders() ([]models.LLMProvider, error) {
	var providers []models.LLMProvider
	err := s.db.Find(&providers).Error
	return providers, err
}

// SaveProvider inserts or updates a provider row (upsert keyed on p.ID).
func (s *ConfigService) SaveProvider(p models.LLMProvider) error {
	return s.db.Save(&p).Error
}

// DeleteProvider removes the provider with the given ID; deleting a
// non-existent ID is not an error.
func (s *ConfigService) DeleteProvider(id string) error {
	return s.db.Delete(&models.LLMProvider{}, "id = ?", id).Error
}
|
||||
|
||||
// --- VectorDB Config ---

// GetVectorDBConfig returns the singleton Qdrant connection settings row.
// Returns gorm.ErrRecordNotFound if the row was never seeded.
func (s *ConfigService) GetVectorDBConfig() (models.VectorDBConfig, error) {
	var cfg models.VectorDBConfig
	err := s.db.First(&cfg).Error
	return cfg, err
}

// SaveVectorDBConfig upserts the singleton settings row; the ID is forced to
// 1 so repeated saves overwrite rather than accumulate rows.
func (s *ConfigService) SaveVectorDBConfig(c models.VectorDBConfig) error {
	c.ID = 1 // singleton
	return s.db.Save(&c).Error
}

// GetDB exposes the underlying global database handle for other services.
func (s *ConfigService) GetDB() *gorm.DB {
	return s.db
}
|
||||
|
||||
// --- Connection Testing ---
|
||||
|
||||
// TestVectorDBConnection verifies the Qdrant server is reachable.
|
||||
func (s *ConfigService) TestVectorDBConnection(endpoint string) (bool, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Normalize endpoint for gRPC (Qdrant default gRPC is 6334, HTTP is 6333)
|
||||
// Strip http:// or https:// prefix
|
||||
if len(endpoint) > 7 && endpoint[:7] == "http://" {
|
||||
endpoint = endpoint[7:]
|
||||
} else if len(endpoint) > 8 && endpoint[:8] == "https://" {
|
||||
endpoint = endpoint[8:]
|
||||
}
|
||||
|
||||
// If user inputs 6333 (HTTP port), we auto-correct to 6334 for gRPC for ease of use
|
||||
if len(endpoint) > 5 && endpoint[len(endpoint)-5:] == ":6333" {
|
||||
endpoint = endpoint[:len(endpoint)-5] + ":6334"
|
||||
}
|
||||
|
||||
store, err := vector.NewQdrantStore(endpoint, 1) // dimension is mostly ignored for list operation
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("建立到 Qdrant 的连接失败 (尝试连接 %s): %w", endpoint, err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
ok, err := store.TestConnection(ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("Qdrant 测试请求错误 (请确保使用的是 gRPC 端口, 默认 6334): %w", err)
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// TestLLMConnection tests connectivity and auth configuration for a provider.
|
||||
func (s *ConfigService) TestLLMConnection(provider, baseURL, apiKey string) (bool, error) {
|
||||
var reqURL string
|
||||
if provider == "Ollama" {
|
||||
reqURL = baseURL + "/api/tags"
|
||||
} else {
|
||||
// DeepSeek/OpenAI compatible models endpoint
|
||||
// Often requires /v1/models. If user provides baseURL with /v1, we just append /models
|
||||
// Or we can just let it try the URL directly if we assume user's base URL is the completions one?
|
||||
// Usually standard is "https://api.deepseek.com" or "https://api.deepseek.com/v1"
|
||||
// Wait, some BaseURLs don't have /v1. Let's do a basic normalization.
|
||||
if baseURL[len(baseURL)-1] == '/' {
|
||||
baseURL = baseURL[:len(baseURL)-1]
|
||||
}
|
||||
if provider == "DeepSeek" || provider == "OpenAI" {
|
||||
reqURL = baseURL + "/models"
|
||||
} else {
|
||||
reqURL = baseURL + "/v1/models"
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if apiKey != "" && provider != "Ollama" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 8 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("网络连通性异常: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 400 {
|
||||
return true, nil
|
||||
}
|
||||
if resp.StatusCode == 401 {
|
||||
return false, fmt.Errorf("API Key 无效 (HTTP 无授权 %d)", resp.StatusCode)
|
||||
}
|
||||
return false, fmt.Errorf("返回非健康状态码: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Global DB Models ---

// LLMProvider stores one LLM provider configuration row in the global DB.
type LLMProvider struct {
	ID        string `gorm:"primaryKey" json:"id"`
	Name      string `json:"name"`     // display name shown in settings
	Provider  string `json:"provider"` // Ollama, DeepSeek, OpenAI, Qwen
	BaseURL   string `json:"url"`
	APIKey    string `json:"key"`
	ModelID   string `json:"model"` // model identifier sent to the API
	Enabled   bool   `json:"enabled"`
	CreatedAt time.Time
	UpdatedAt time.Time
}

// VectorDBConfig stores Qdrant connection settings (singleton row, ID 1).
type VectorDBConfig struct {
	ID       uint   `gorm:"primaryKey" json:"id"`
	Endpoint string `json:"endpoint"`
	APIKey   string `json:"apiKey"`
	Status   string `json:"status"` // connected, disconnected
}

// Project represents a top-level engineering project.
type Project struct {
	ID        string `gorm:"primaryKey" json:"id"`
	Name      string `json:"name"`
	Path      string `json:"path"` // project DB file path
	CreatedAt time.Time
	UpdatedAt time.Time
}

// --- Project-scoped DB Models ---

// SourceFile represents an imported engineering document.
type SourceFile struct {
	ID            string `gorm:"primaryKey" json:"id"`
	ProjectID     string `gorm:"index" json:"projectId"`
	Name          string `json:"name"`
	Type          string `json:"type"` // pdf, cad, gis, excel, word
	Category      string `json:"category"`
	FilePath      string `json:"filePath"`
	Size          string `json:"size"` // human-readable size string
	ParsedContent string `json:"parsedContent,omitempty"`
	VectorStatus  string `json:"vectorStatus"` // pending, processing, done, error
	CreatedAt     time.Time
	UpdatedAt     time.Time
}

// ChatMessage stores conversation history per project.
type ChatMessage struct {
	ID        uint   `gorm:"primaryKey;autoIncrement" json:"id"`
	ProjectID string `gorm:"index" json:"projectId"`
	Role      string `json:"role"` // user, assistant
	Content   string `json:"content"`
	Sources   string `json:"sources,omitempty"`   // JSON array of source names
	Citations string `json:"citations,omitempty"` // JSON array of citation objects
	CreatedAt time.Time
}

// TemplateChapter stores delivery template chapters per project.
type TemplateChapter struct {
	ID           string `gorm:"primaryKey" json:"id"`
	ProjectID    string `gorm:"index" json:"projectId"`
	TemplateName string `json:"templateName"`
	Title        string `json:"title"`
	Status       string `json:"status"` // idle, loading, done
	Progress     int    `json:"progress"`
	Content      string `json:"content"` // generated chapter body (Markdown)
	SortOrder    int    `json:"sortOrder"`
	CreatedAt    time.Time
	UpdatedAt    time.Time
}

// TextChunk represents a vectorized text segment of a source file.
type TextChunk struct {
	ID        string `gorm:"primaryKey" json:"id"`
	ProjectID string `gorm:"index" json:"projectId"`
	SourceID  string `gorm:"index" json:"sourceId"`
	Content   string `json:"content"`
	ChunkIdx  int    `json:"chunkIdx"` // position within the source document
	CreatedAt time.Time
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CADParser is a stub parser for CAD/DWG files.
// Real implementation requires CGO/GDAL — reserved for future development.
type CADParser struct{}

// SupportedExtensions reports the drawing formats this stub claims to handle.
func (p *CADParser) SupportedExtensions() []string {
	return []string{".dwg", ".dxf"}
}

// ParseToMarkdown returns a canned Markdown report simulating layer and
// annotation extraction; only the file path is taken from the input.
func (p *CADParser) ParseToMarkdown(path string) (string, error) {
	sections := []string{
		"## CAD 图纸解析结果 (模拟)\n\n",
		fmt.Sprintf("**文件**: %s\n\n", path),
		"### 图层列表\n\n",
		"| 图层名 | 类型 | 元素数 |\n",
		"| --- | --- | --- |\n",
		"| STR_MAIN | 结构主体 | 142 |\n",
		"| DIM_TEXT | 标注文字 | 87 |\n",
		"| SEAL_V3 | 密封层 | 23 |\n\n",
		"### 标注摘要\n\n",
		"- 预留缝宽度: 2.5mm\n",
		"- 坐标基点: X=1240, Y=442\n",
		"\n> ⚠️ 本解析为模拟结果,完整 CAD 解析需集成 GDAL/LibreDWG。\n",
	}
	return strings.Join(sections, ""), nil
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/xuri/excelize/v2"
|
||||
)
|
||||
|
||||
// ExcelParser extracts all sheets from Excel files as Markdown tables.
|
||||
type ExcelParser struct{}
|
||||
|
||||
func (p *ExcelParser) SupportedExtensions() []string {
|
||||
return []string{".xlsx", ".xls"}
|
||||
}
|
||||
|
||||
func (p *ExcelParser) ParseToMarkdown(path string) (string, error) {
|
||||
f, err := excelize.OpenFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("open excel: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var sb strings.Builder
|
||||
for _, sheet := range f.GetSheetList() {
|
||||
rows, err := f.GetRows(sheet)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("## Sheet: %s\n\n", sheet))
|
||||
|
||||
// Header row
|
||||
sb.WriteString("| ")
|
||||
for _, cell := range rows[0] {
|
||||
sb.WriteString(cell + " | ")
|
||||
}
|
||||
sb.WriteString("\n|")
|
||||
for range rows[0] {
|
||||
sb.WriteString(" --- |")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Data rows
|
||||
for _, row := range rows[1:] {
|
||||
sb.WriteString("| ")
|
||||
for i := 0; i < len(rows[0]); i++ {
|
||||
if i < len(row) {
|
||||
sb.WriteString(row[i])
|
||||
}
|
||||
sb.WriteString(" | ")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Parser interface for all document parsers.
|
||||
type Parser interface {
|
||||
SupportedExtensions() []string
|
||||
ParseToMarkdown(path string) (string, error)
|
||||
}
|
||||
|
||||
// Registry holds registered parsers keyed by extension.
|
||||
type Registry struct {
|
||||
parsers map[string]Parser
|
||||
}
|
||||
|
||||
// NewRegistry creates a parser registry with all built-in parsers.
|
||||
func NewRegistry() *Registry {
|
||||
r := &Registry{parsers: make(map[string]Parser)}
|
||||
r.Register(&ExcelParser{})
|
||||
r.Register(&PDFParser{})
|
||||
r.Register(&CADParser{})
|
||||
r.Register(&WordParser{})
|
||||
return r
|
||||
}
|
||||
|
||||
// Register adds a parser for its supported extensions.
|
||||
func (r *Registry) Register(p Parser) {
|
||||
for _, ext := range p.SupportedExtensions() {
|
||||
r.parsers[strings.ToLower(ext)] = p
|
||||
}
|
||||
}
|
||||
|
||||
// Parse dispatches to the appropriate parser based on file extension.
|
||||
func (r *Registry) Parse(path string) (string, error) {
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
p, ok := r.parsers[ext]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported file type: %s", ext)
|
||||
}
|
||||
return p.ParseToMarkdown(path)
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ledongthuc/pdf"
|
||||
)
|
||||
|
||||
// PDFParser extracts plain text from PDF files using ledongthuc/pdf (open source).
type PDFParser struct{}

// SupportedExtensions reports the formats handled by this parser.
func (p *PDFParser) SupportedExtensions() []string {
	return []string{".pdf"}
}

// ParseToMarkdown extracts text page by page, emitting a "## Page N" heading
// before each non-empty page. Pages that are null or fail text extraction are
// skipped. If no page yields text, it falls back to a whole-document read;
// only if that also produces nothing is an error returned.
func (p *PDFParser) ParseToMarkdown(path string) (string, error) {
	f, r, err := pdf.Open(path)
	if err != nil {
		return "", fmt.Errorf("open pdf: %w", err)
	}
	defer f.Close()

	var sb strings.Builder
	totalPages := r.NumPage()

	// Page numbering in the pdf library is 1-based.
	for i := 1; i <= totalPages; i++ {
		page := r.Page(i)
		if page.V.IsNull() {
			continue
		}
		text, err := page.GetPlainText(nil)
		if err != nil {
			// Skip pages whose content streams cannot be decoded.
			continue
		}
		content := strings.TrimSpace(text)
		if content != "" {
			sb.WriteString(fmt.Sprintf("## Page %d\n\n%s\n\n", i, content))
		}
	}

	if sb.Len() == 0 {
		// Fallback: try reading the entire document in one pass.
		content, err := readPDFPlainText(path)
		if err == nil && content != "" {
			return content, nil
		}
		return "", fmt.Errorf("no text content extracted from PDF")
	}

	return sb.String(), nil
}
|
||||
|
||||
// readPDFPlainText is the fallback extraction path: it re-opens the file
// directly and concatenates the plain text of every page without headings.
// Per-page extraction errors are ignored; the result may be empty.
func readPDFPlainText(path string) (string, error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()
	// Stat error is ignored; a nil-stat panic would only occur if Open
	// succeeded but Stat failed, which is not expected on a regular file.
	stat, _ := f.Stat()
	reader, err := pdf.NewReader(f, stat.Size())
	if err != nil {
		return "", err
	}
	var sb strings.Builder
	for i := 1; i <= reader.NumPage(); i++ {
		page := reader.Page(i)
		if page.V.IsNull() {
			continue
		}
		text, _ := page.GetPlainText(nil)
		sb.WriteString(text)
	}
	return sb.String(), nil
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/wailsapp/wails/v3/pkg/application"
|
||||
)
|
||||
|
||||
// Job represents a file parsing task queued for a worker.
type Job struct {
	FilePath  string // absolute path of the document to parse
	FileID    string // SourceFile ID the result belongs to
	ProjectID string // owning project
}

// Result is the outcome of a parsing job. On failure Err is non-nil and
// Content may be empty.
type Result struct {
	Job     Job
	Content string // extracted Markdown text
	Err     error
}
|
||||
|
||||
// ProcessingQueue manages concurrent file parsing with a fixed worker pool.
// Jobs flow Submit → jobs → worker → results → onComplete.
type ProcessingQueue struct {
	registry    *Registry
	jobs        chan Job    // buffered; Submit blocks once 100 jobs are pending
	results     chan Result // consumed by the single collector goroutine
	concurrency int         // number of worker goroutines
	wg          sync.WaitGroup
	onComplete  func(Result) // invoked serially for each finished job; may be nil
}

// NewProcessingQueue creates a queue with the given concurrency limit and
// immediately starts its workers and result collector.
func NewProcessingQueue(registry *Registry, concurrency int, onComplete func(Result)) *ProcessingQueue {
	q := &ProcessingQueue{
		registry:    registry,
		jobs:        make(chan Job, 100),
		results:     make(chan Result, 100),
		concurrency: concurrency,
		onComplete:  onComplete,
	}
	q.start()
	return q
}
|
||||
|
||||
// start launches the worker pool and the single result-collector goroutine.
// Workers exit when q.jobs is closed; the collector exits when q.results is
// closed (both happen in Close).
func (q *ProcessingQueue) start() {
	// Start workers
	for i := 0; i < q.concurrency; i++ {
		go q.worker(i)
	}
	// Result collector: serializes onComplete calls, so the callback does not
	// need to be safe for concurrent invocation.
	go func() {
		for result := range q.results {
			if q.onComplete != nil {
				q.onComplete(result)
			}
		}
	}()
}

// worker drains q.jobs until the channel is closed. Parse failures are
// logged and still forwarded as a Result with Err set, so the collector
// always sees one Result per submitted Job. wg.Done pairs with the
// wg.Add(1) in Submit.
func (q *ProcessingQueue) worker(id int) {
	for job := range q.jobs {
		slog.Info("parser worker processing", "worker", id, "file", job.FilePath)
		content, err := q.registry.Parse(job.FilePath)
		if err != nil {
			slog.Error("parse failed", "file", job.FilePath, "err", err)
		}
		q.results <- Result{Job: job, Content: content, Err: err}
		q.wg.Done()
	}
}
|
||||
|
||||
// Submit adds a job to the processing queue. Blocks when the job buffer
// (capacity 100) is full. Submitting after Close panics (send on a closed
// channel).
func (q *ProcessingQueue) Submit(job Job) {
	q.wg.Add(1)
	q.jobs <- job
}

// Wait blocks until every submitted job has produced a Result. Note this
// does not guarantee the onComplete callback has run for all of them yet.
func (q *ProcessingQueue) Wait() {
	q.wg.Wait()
}

// Close shuts down the queue: stops workers, waits for in-flight jobs, then
// stops the result collector. The queue cannot be reused afterwards.
func (q *ProcessingQueue) Close() {
	close(q.jobs)
	q.wg.Wait()
	close(q.results)
}
|
||||
|
||||
// ParseService wraps the parsing pipeline for Wails binding.
type ParseService struct {
	registry *Registry
	queue    *ProcessingQueue // NOTE(review): never initialized here — confirm who sets it
}

// NewParseService creates a new parse service with all built-in parsers
// registered. The background queue is left nil.
func NewParseService() *ParseService {
	return &ParseService{
		registry: NewRegistry(),
	}
}

// ParseFile synchronously parses a single file into Markdown. For Wails binding.
func (s *ParseService) ParseFile(path string) (string, error) {
	content, err := s.registry.Parse(path)
	if err != nil {
		return "", fmt.Errorf("parse %s: %w", path, err)
	}
	return content, nil
}

// GetSupportedTypes returns the file extensions the frontend may offer for
// import. Kept as a hard-coded list; it should mirror the registry's parsers.
func (s *ParseService) GetSupportedTypes() []string {
	return []string{".xlsx", ".xls", ".pdf", ".dwg", ".dxf", ".docx"}
}
|
||||
|
||||
// ParseDeliveryStandard opens a file dialog to select a document, parses it, and returns the markdown.
// Returns ("", nil) when the user cancels the dialog.
//
// NOTE(review): the filter offers only *.pdf;*.xlsx;*.xls;*.docx even
// though GetSupportedTypes also lists .dwg/.dxf — confirm the
// narrower set is intentional for delivery-standard documents.
func (s *ParseService) ParseDeliveryStandard() (string, error) {
	// Wails v3 file-open dialog; title is user-facing (Chinese UI).
	dialog := application.Get().Dialog.OpenFile()
	dialog.SetTitle("选择交付标准文件 (Delivery Standard)")
	dialog.AddFilter("Documents", "*.pdf;*.xlsx;*.xls;*.docx")

	path, err := dialog.PromptForSingleSelection()
	if err != nil {
		return "", fmt.Errorf("open file dialog: %w", err)
	}
	if path == "" {
		return "", nil // user cancelled
	}
	return s.ParseFile(path)
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package parser
|
||||
|
||||
import (
	"archive/zip"
	"encoding/xml"
	"fmt"
	"io"
	"strings"
)
|
||||
|
||||
// WordParser extracts text from .docx files.
// A .docx is an OOXML zip package; only the main document part is read.
type WordParser struct{}
|
||||
|
||||
// SupportedExtensions returns extensions this parser handles
|
||||
func (p *WordParser) SupportedExtensions() []string {
|
||||
return []string{".docx"}
|
||||
}
|
||||
|
||||
// Minimal mirror of OOXML word/document.xml: a body containing
// paragraphs (w:p), each holding runs (w:r) with text nodes (w:t).
// encoding/xml matches elements by local name when the tag carries no
// namespace, so the w: prefix is not spelled out here.
type wDocument struct {
	Body wBody `xml:"body"`
}

// wBody holds the document's paragraphs.
type wBody struct {
	P []wP `xml:"p"`
}

// wP is a single paragraph: a sequence of runs.
type wP struct {
	R []wR `xml:"r"`
}

// wR is a single run; T collects all of its text nodes.
type wR struct {
	T []string `xml:"t"`
}
|
||||
|
||||
// ParseToMarkdown extracts text from the word document and returns it as Markdown-like text
|
||||
func (p *WordParser) ParseToMarkdown(path string) (string, error) {
|
||||
r, err := zip.OpenReader(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open docx as zip: %w", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
var docFile *zip.File
|
||||
for _, f := range r.File {
|
||||
if f.Name == "word/document.xml" {
|
||||
docFile = f
|
||||
break
|
||||
}
|
||||
}
|
||||
if docFile == nil {
|
||||
return "", fmt.Errorf("invalid docx file: word/document.xml not found")
|
||||
}
|
||||
|
||||
rc, err := docFile.Open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
data, err := io.ReadAll(rc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var doc wDocument
|
||||
if err := xml.Unmarshal(data, &doc); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var md string
|
||||
for _, paragraph := range doc.Body.P {
|
||||
var pText string
|
||||
for _, run := range paragraph.R {
|
||||
for _, text := range run.T {
|
||||
pText += text
|
||||
}
|
||||
}
|
||||
if pText != "" {
|
||||
md += pText + "\n\n"
|
||||
}
|
||||
}
|
||||
|
||||
return md, nil
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"engimind/internal/models"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// ProjectService manages multi-project lifecycle with isolated SQLite databases.
type ProjectService struct {
	// globalDB is the cross-project registry (set via SetGlobalDB).
	globalDB *gorm.DB
	// projectDB is the currently open per-project database; nil when
	// no project is active.
	projectDB *gorm.DB
	// currentID is the active project's ID; "" when none is open.
	currentID string
	// mu guards projectDB/currentID during switch and delete.
	// NOTE(review): CreateProject and the getters touch those fields
	// without taking mu — confirm whether that race is acceptable.
	mu sync.Mutex
	// baseDir is the root folder holding one subdirectory per project.
	baseDir string
}
|
||||
|
||||
// NewProjectService creates a project service. Call SetGlobalDB before use.
|
||||
func NewProjectService() *ProjectService {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
return &ProjectService{
|
||||
baseDir: filepath.Join(homeDir, ".engimind", "projects"),
|
||||
}
|
||||
}
|
||||
|
||||
// SetGlobalDB injects the global database reference.
// Every other method assumes this has been called first (see
// NewProjectService's doc comment).
func (s *ProjectService) SetGlobalDB(db *gorm.DB) {
	s.globalDB = db
}
|
||||
|
||||
// ListProjects returns all registered projects.
|
||||
func (s *ProjectService) ListProjects() ([]models.Project, error) {
|
||||
var projects []models.Project
|
||||
err := s.globalDB.Order("created_at DESC").Find(&projects).Error
|
||||
return projects, err
|
||||
}
|
||||
|
||||
// CreateProject creates a new project with its own SQLite database.
|
||||
func (s *ProjectService) CreateProject(name string) (*models.Project, error) {
|
||||
id := fmt.Sprintf("p-%d", time.Now().UnixMilli())
|
||||
projDir := filepath.Join(s.baseDir, id)
|
||||
if err := os.MkdirAll(projDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("create project dir: %w", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(projDir, "project.db")
|
||||
proj := models.Project{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Path: dbPath,
|
||||
}
|
||||
if err := s.globalDB.Create(&proj).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize project DB
|
||||
if err := s.openProjectDB(dbPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.currentID = id
|
||||
return &proj, nil
|
||||
}
|
||||
|
||||
// SwitchProject closes current project DB and opens a new one.
|
||||
func (s *ProjectService) SwitchProject(id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.currentID == id && s.projectDB != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close current
|
||||
s.closeCurrentDB()
|
||||
|
||||
// Find project
|
||||
var proj models.Project
|
||||
if err := s.globalDB.First(&proj, "id = ?", id).Error; err != nil {
|
||||
return fmt.Errorf("project not found: %w", err)
|
||||
}
|
||||
|
||||
if err := s.openProjectDB(proj.Path); err != nil {
|
||||
return err
|
||||
}
|
||||
s.currentID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteProject removes the project and its database files.
|
||||
func (s *ProjectService) DeleteProject(id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.currentID == id {
|
||||
s.closeCurrentDB()
|
||||
}
|
||||
|
||||
var proj models.Project
|
||||
if err := s.globalDB.First(&proj, "id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove DB files
|
||||
projDir := filepath.Dir(proj.Path)
|
||||
os.RemoveAll(projDir)
|
||||
|
||||
return s.globalDB.Delete(&proj).Error
|
||||
}
|
||||
|
||||
// GetCurrentProjectID returns the active project ID.
|
||||
func (s *ProjectService) GetCurrentProjectID() string {
|
||||
return s.currentID
|
||||
}
|
||||
|
||||
// GetProjectDB returns the current project's database.
|
||||
func (s *ProjectService) GetProjectDB() *gorm.DB {
|
||||
return s.projectDB
|
||||
}
|
||||
|
||||
// openProjectDB opens (creating on first use) the SQLite database at
// dbPath, migrates the per-project schema, and installs it as the
// active projectDB. It does not close a previously open projectDB.
// NOTE(review): SwitchProject/DeleteProject call this under s.mu but
// CreateProject does not — confirm locking expectations.
func (s *ProjectService) openProjectDB(dbPath string) error {
	// Silent GORM logger: per-project DBs are opened and closed often.
	db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		return fmt.Errorf("open project db: %w", err)
	}

	// Per-project schema; keep this list in sync with the models package.
	if err := db.AutoMigrate(
		&models.SourceFile{},
		&models.ChatMessage{},
		&models.TemplateChapter{},
		&models.TextChunk{},
	); err != nil {
		return fmt.Errorf("auto migrate project db: %w", err)
	}

	s.projectDB = db
	return nil
}
|
||||
|
||||
func (s *ProjectService) closeCurrentDB() {
|
||||
if s.projectDB != nil {
|
||||
sqlDB, err := s.projectDB.DB()
|
||||
if err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
s.projectDB = nil
|
||||
s.currentID = ""
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package vector
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EmbeddingService calls Ollama or OpenAI-compatible APIs for embeddings.
type EmbeddingService struct {
	// client is shared across requests; constructed with a 60s timeout.
	client *http.Client
}
|
||||
|
||||
// NewEmbeddingService creates an embedding service.
|
||||
func NewEmbeddingService() *EmbeddingService {
|
||||
return &EmbeddingService{
|
||||
client: &http.Client{Timeout: 60 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// ollamaEmbedReq is the request body for the Ollama embedding API
// (POST /api/embeddings).
type ollamaEmbedReq struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
}

// ollamaEmbedResp is the corresponding response body.
type ollamaEmbedResp struct {
	Embedding []float32 `json:"embedding"`
}
|
||||
|
||||
// openAIEmbedReq is the request body for an OpenAI-compatible
// embeddings endpoint (POST /v1/embeddings, single-string input form).
type openAIEmbedReq struct {
	Model string `json:"model"`
	Input string `json:"input"`
}

// openAIEmbedResp is the corresponding response body; only the first
// data item's embedding is consumed.
type openAIEmbedResp struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
	} `json:"data"`
}
|
||||
|
||||
// GetEmbedding generates an embedding vector for the given text.
|
||||
// provider: "ollama" or "openai" (compatible format)
|
||||
func (s *EmbeddingService) GetEmbedding(text, baseURL, model, apiKey, provider string) ([]float32, error) {
|
||||
switch provider {
|
||||
case "Ollama":
|
||||
return s.ollamaEmbed(text, baseURL, model)
|
||||
default:
|
||||
return s.openAIEmbed(text, baseURL, model, apiKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *EmbeddingService) ollamaEmbed(text, baseURL, model string) ([]float32, error) {
|
||||
body, _ := json.Marshal(ollamaEmbedReq{Model: model, Prompt: text})
|
||||
resp, err := s.client.Post(baseURL+"/api/embeddings", "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama embed request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
var result ollamaEmbedResp
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse ollama response: %w", err)
|
||||
}
|
||||
if len(result.Embedding) == 0 {
|
||||
return nil, fmt.Errorf("empty embedding returned")
|
||||
}
|
||||
return result.Embedding, nil
|
||||
}
|
||||
|
||||
func (s *EmbeddingService) openAIEmbed(text, baseURL, model, apiKey string) ([]float32, error) {
|
||||
body, _ := json.Marshal(openAIEmbedReq{Model: model, Input: text})
|
||||
req, _ := http.NewRequest("POST", baseURL+"/v1/embeddings", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai embed request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
var result openAIEmbedResp
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse openai response: %w", err)
|
||||
}
|
||||
if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
|
||||
return nil, fmt.Errorf("empty embedding returned")
|
||||
}
|
||||
return result.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
// ChunkText splits text into overlapping chunks for vectorization.
// chunkSize: target characters (runes) per chunk, overlap: characters
// of overlap between consecutive chunks. Text that fits in one chunk
// (or a non-positive chunkSize) is returned unsplit.
//
// Fixes over the original:
//   - overlap >= chunkSize made the step (chunkSize - overlap) zero or
//     negative and the loop spun forever; the step is now clamped to a
//     full chunk in that case.
//   - the final iterations could emit a tiny trailing chunk fully
//     contained in the previous one; chunking now stops once a chunk
//     reaches the end of the text.
func ChunkText(text string, chunkSize, overlap int) []string {
	runes := []rune(text)
	if chunkSize <= 0 || len(runes) <= chunkSize {
		return []string{text}
	}

	step := chunkSize - overlap
	if step <= 0 {
		step = chunkSize // degenerate overlap: fall back to disjoint chunks
	}

	var chunks []string
	for start := 0; start < len(runes); start += step {
		end := start + chunkSize
		if end >= len(runes) {
			// Last chunk reaches the end; nothing new remains after it.
			chunks = append(chunks, string(runes[start:]))
			break
		}
		chunks = append(chunks, string(runes[start:end]))
	}
	return chunks
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package vector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"engimind/internal/models"
|
||||
)
|
||||
|
||||
// ContextChunk is a search result with source metadata.
|
||||
type ContextChunk struct {
|
||||
Text string `json:"text"`
|
||||
SourceID string `json:"sourceId"`
|
||||
Score float32 `json:"score"`
|
||||
}
|
||||
|
||||
// RAGService orchestrates embedding + vector search for retrieval.
|
||||
type RAGService struct {
|
||||
embedding *EmbeddingService
|
||||
store *QdrantStore
|
||||
}
|
||||
|
||||
// NewRAGService creates a RAG service.
|
||||
func NewRAGService(embedding *EmbeddingService, store *QdrantStore) *RAGService {
|
||||
return &RAGService{embedding: embedding, store: store}
|
||||
}
|
||||
|
||||
// CollectionName returns the Qdrant collection name for a project.
|
||||
func CollectionName(projectID string) string {
|
||||
return fmt.Sprintf("engimind_%s", projectID)
|
||||
}
|
||||
|
||||
// IndexDocument chunks and indexes a parsed document.
|
||||
func (s *RAGService) IndexDocument(ctx context.Context, projectID string, source models.SourceFile, content string, embeddingCfg EmbeddingConfig) error {
|
||||
colName := CollectionName(projectID)
|
||||
if err := s.store.EnsureCollection(ctx, colName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
textChunks := ChunkText(content, 500, 50)
|
||||
var chunks []Chunk
|
||||
for i, text := range textChunks {
|
||||
vec, err := s.embedding.GetEmbedding(
|
||||
text, embeddingCfg.BaseURL, embeddingCfg.Model,
|
||||
embeddingCfg.APIKey, embeddingCfg.Provider,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("embed chunk %d: %w", i, err)
|
||||
}
|
||||
chunks = append(chunks, Chunk{
|
||||
ID: fmt.Sprintf("%s-chunk-%d", source.ID, i),
|
||||
SourceID: source.ID,
|
||||
Text: text,
|
||||
Vector: vec,
|
||||
})
|
||||
}
|
||||
|
||||
return s.store.Insert(ctx, colName, chunks)
|
||||
}
|
||||
|
||||
// SearchContext retrieves relevant text chunks for a query.
//
// The question is embedded with the caller-supplied provider config,
// then a top-K KNN search runs against the project's collection and
// the hits are converted to ContextChunks.
//
// NOTE(review): ContextChunk.Score is left at its zero value — the
// store's Search does not return scores — so callers cannot rank or
// threshold results by similarity.
func (s *RAGService) SearchContext(ctx context.Context, projectID, question string, topK int, embeddingCfg EmbeddingConfig) ([]ContextChunk, error) {
	queryVec, err := s.embedding.GetEmbedding(
		question, embeddingCfg.BaseURL, embeddingCfg.Model,
		embeddingCfg.APIKey, embeddingCfg.Provider,
	)
	if err != nil {
		return nil, fmt.Errorf("embed query: %w", err)
	}

	colName := CollectionName(projectID)
	results, err := s.store.Search(ctx, colName, queryVec, uint64(topK))
	if err != nil {
		return nil, err
	}

	contextChunks := make([]ContextChunk, len(results))
	for i, r := range results {
		contextChunks[i] = ContextChunk{
			Text:     r.Text,
			SourceID: r.SourceID,
		}
	}
	return contextChunks, nil
}
|
||||
|
||||
// EmbeddingConfig holds the config needed to call an embedding API.
type EmbeddingConfig struct {
	// BaseURL is the API root; the service appends the endpoint path
	// ("/api/embeddings" or "/v1/embeddings").
	BaseURL string
	// Model is the embedding model identifier.
	Model string
	// APIKey is sent as a Bearer token when non-empty.
	APIKey string
	// Provider selects the API flavor: "Ollama" (exact match) uses the
	// Ollama API, anything else the OpenAI-compatible one.
	Provider string
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package vector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
pb "github.com/qdrant/go-client/qdrant"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
// Chunk is a text segment with its vector.
type Chunk struct {
	// ID is the point ID; Insert stores it as a Qdrant UUID point ID.
	ID string
	// SourceID identifies the originating source file.
	SourceID string
	// Text is the raw chunk text (persisted in the point payload).
	Text string
	// Vector is the embedding; its length should match the collection
	// dimension configured on the store.
	Vector []float32
}
|
||||
|
||||
// QdrantStore implements vector storage via remote Qdrant gRPC.
type QdrantStore struct {
	conn       *grpc.ClientConn     // underlying connection; released by Close
	points     pb.PointsClient      // point upsert / search API
	collection pb.CollectionsClient // collection management API
	dimension  uint64               // vector size used when creating collections
}
|
||||
|
||||
// NewQdrantStore connects to a Qdrant instance.
|
||||
func NewQdrantStore(endpoint string, dimension uint64) (*QdrantStore, error) {
|
||||
conn, err := grpc.NewClient(endpoint, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect qdrant: %w", err)
|
||||
}
|
||||
return &QdrantStore{
|
||||
conn: conn,
|
||||
points: pb.NewPointsClient(conn),
|
||||
collection: pb.NewCollectionsClient(conn),
|
||||
dimension: dimension,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnsureCollection creates a collection if it doesn't exist.
//
// Existence is probed with a Get: any error from Get (including
// transport errors) is treated as "missing" and triggers a Create.
// New collections use cosine distance and the store's configured
// dimension.
// NOTE(review): a collection that already exists with a different
// dimension is left untouched — a mismatch would only surface later —
// confirm this is acceptable.
func (s *QdrantStore) EnsureCollection(ctx context.Context, name string) error {
	_, err := s.collection.Get(ctx, &pb.GetCollectionInfoRequest{CollectionName: name})
	if err == nil {
		return nil // already exists
	}

	_, err = s.collection.Create(ctx, &pb.CreateCollection{
		CollectionName: name,
		VectorsConfig: &pb.VectorsConfig{
			Config: &pb.VectorsConfig_Params{
				Params: &pb.VectorParams{
					Size:     s.dimension,
					Distance: pb.Distance_Cosine,
				},
			},
		},
	})
	if err != nil {
		return fmt.Errorf("create collection %s: %w", name, err)
	}
	slog.Info("created qdrant collection", "name", name, "dim", s.dimension)
	return nil
}
|
||||
|
||||
// Insert upserts chunks into the specified collection.
//
// Each chunk becomes one point: its vector plus a payload carrying the
// text and source_id (read back by Search).
// NOTE(review): chunk IDs are sent as UUID point IDs; Qdrant rejects
// point IDs that are not valid UUIDs or unsigned integers — confirm
// the "<sourceID>-chunk-<i>" IDs produced upstream parse as UUIDs.
func (s *QdrantStore) Insert(ctx context.Context, collectionName string, chunks []Chunk) error {
	points := make([]*pb.PointStruct, len(chunks))
	for i, c := range chunks {
		points[i] = &pb.PointStruct{
			Id: &pb.PointId{
				PointIdOptions: &pb.PointId_Uuid{Uuid: c.ID},
			},
			Vectors: &pb.Vectors{
				VectorsOptions: &pb.Vectors_Vector{
					Vector: &pb.Vector{Data: c.Vector},
				},
			},
			Payload: map[string]*pb.Value{
				"text":      {Kind: &pb.Value_StringValue{StringValue: c.Text}},
				"source_id": {Kind: &pb.Value_StringValue{StringValue: c.SourceID}},
			},
		}
	}

	_, err := s.points.Upsert(ctx, &pb.UpsertPoints{
		CollectionName: collectionName,
		Points:         points,
	})
	return err
}
|
||||
|
||||
// Search performs KNN search and returns top-k results.
//
// Payload fields "text" and "source_id" (written by Insert) are copied
// back into Chunk; a missing payload field yields an empty string.
// Returned chunks carry no Vector, and the hit's similarity score is
// discarded — Chunk has no field for it.
func (s *QdrantStore) Search(ctx context.Context, collectionName string, queryVec []float32, topK uint64) ([]Chunk, error) {
	resp, err := s.points.Search(ctx, &pb.SearchPoints{
		CollectionName: collectionName,
		Vector:         queryVec,
		Limit:          topK,
		// Payload must be requested explicitly or text/source_id come back empty.
		WithPayload: &pb.WithPayloadSelector{SelectorOptions: &pb.WithPayloadSelector_Enable{Enable: true}},
	})
	if err != nil {
		return nil, fmt.Errorf("qdrant search: %w", err)
	}

	results := make([]Chunk, 0, len(resp.Result))
	for _, hit := range resp.Result {
		text := ""
		sourceID := ""
		if v, ok := hit.Payload["text"]; ok {
			text = v.GetStringValue()
		}
		if v, ok := hit.Payload["source_id"]; ok {
			sourceID = v.GetStringValue()
		}
		results = append(results, Chunk{
			ID:       hit.Id.GetUuid(),
			SourceID: sourceID,
			Text:     text,
		})
	}
	return results, nil
}
|
||||
|
||||
// DeleteCollection removes a collection.
|
||||
func (s *QdrantStore) DeleteCollection(ctx context.Context, name string) error {
|
||||
_, err := s.collection.Delete(ctx, &pb.DeleteCollection{CollectionName: name})
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the gRPC connection.
|
||||
func (s *QdrantStore) Close() {
|
||||
if s.conn != nil {
|
||||
s.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnection verifies the Qdrant server is reachable.
|
||||
func (s *QdrantStore) TestConnection(ctx context.Context) (bool, error) {
|
||||
_, err := s.collection.List(ctx, &pb.ListCollectionsRequest{})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
Reference in New Issue
Block a user