91 lines
2.2 KiB
Go
91 lines
2.2 KiB
Go
package service
|
|
|
|
import (
	"bufio"
	"encoding/csv"
	"errors"
	"fmt"
	"io"
	"os"
	"strings"

	"AI-Expert-Sidebar/internal/database"
	"AI-Expert-Sidebar/internal/models"
)
|
|
|
|
// ImportResult reports how a CSV import went. It is the value returned by
// ImportCSV and carries JSON tags so it can be serialised directly.
type ImportResult struct {
	// Imported is the number of rows that were stored successfully.
	Imported int `json:"imported"`

	// Skipped is the number of rows that were not stored.
	Skipped int `json:"skipped"`

	// Error holds a message when the import could not proceed; it is left
	// out of the JSON output when empty.
	Error string `json:"error,omitempty"`
}
|
|
|
|
// ImportCSV reads a CSV file and inserts records into the active knowledge library.
|
|
//
|
|
// Required columns (case-insensitive): keyword, question, answer
|
|
// Optional column: category (defaults to "通用")
|
|
//
|
|
// The first row must be the header.
|
|
func ImportCSV(filePath string) ImportResult {
|
|
f, err := os.Open(filePath)
|
|
if err != nil {
|
|
return ImportResult{Error: fmt.Sprintf("无法打开文件: %v", err)}
|
|
}
|
|
defer f.Close()
|
|
|
|
db := database.Get()
|
|
if db == nil {
|
|
return ImportResult{Error: "知识库未初始化"}
|
|
}
|
|
|
|
r := csv.NewReader(f)
|
|
r.TrimLeadingSpace = true
|
|
r.LazyQuotes = true
|
|
|
|
// Read and normalise header
|
|
header, err := r.Read()
|
|
if err != nil {
|
|
return ImportResult{Error: fmt.Sprintf("读取表头失败: %v", err)}
|
|
}
|
|
colIdx := make(map[string]int)
|
|
for i, h := range header {
|
|
colIdx[strings.ToLower(strings.TrimSpace(h))] = i
|
|
}
|
|
for _, required := range []string{"keyword", "question", "answer"} {
|
|
if _, ok := colIdx[required]; !ok {
|
|
return ImportResult{Error: fmt.Sprintf("CSV 缺少必需列: %q (需要: keyword, question, answer)", required)}
|
|
}
|
|
}
|
|
catIdx, hasCat := colIdx["category"]
|
|
|
|
var imported, skipped int
|
|
for {
|
|
row, err := r.Read()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
skipped++
|
|
continue
|
|
}
|
|
keyword := strings.TrimSpace(row[colIdx["keyword"]])
|
|
question := strings.TrimSpace(row[colIdx["question"]])
|
|
answer := strings.TrimSpace(row[colIdx["answer"]])
|
|
if keyword == "" || question == "" || answer == "" {
|
|
skipped++
|
|
continue
|
|
}
|
|
cat := "通用"
|
|
if hasCat && catIdx < len(row) {
|
|
if v := strings.TrimSpace(row[catIdx]); v != "" {
|
|
cat = v
|
|
}
|
|
}
|
|
entry := models.Entry{Keyword: keyword, Question: question, Answer: answer, Category: cat}
|
|
if err := db.Create(&entry).Error; err != nil {
|
|
skipped++
|
|
} else {
|
|
imported++
|
|
}
|
|
}
|
|
return ImportResult{Imported: imported, Skipped: skipped}
|
|
}
|