"""
|
|
Engimind Python Backend — FastAPI application
|
|
All API routes matching the original Go/Wails backend functionality.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import time
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel
|
|
from sse_starlette.sse import EventSourceResponse
|
|
|
|
from database import db
|
|
from models import (
|
|
LLMProvider, VectorDBConfig, EmbeddingModelConfig, Project,
|
|
SourceFile, ChatMessage, TemplateChapter, TextChunk, DeliveryStandard,
|
|
LLMProviderSchema, VectorDBConfigSchema, EmbeddingConfigSchema,
|
|
ProjectSchema, SourceFileSchema, ChatMessageSchema,
|
|
TemplateChapterSchema, DeliveryStandardSchema,
|
|
)
|
|
from llm_client import llm_client
|
|
from vector_service import (
|
|
embedding_service, vector_store, rag_service,
|
|
chunk_text, search_text_chunks_keyword,
|
|
)
|
|
from parsers.registry import (
|
|
parse_file, detect_file_type, categorize_file, get_file_size,
|
|
SUPPORTED_EXTENSIONS,
|
|
)
|
|
from parsers.excel_parser import parse_excel_to_chunks, pre_parse_excel
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
|
logger = logging.getLogger("engimind")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialize the global DB, best-effort connect to
    Qdrant, auto-open the most recently updated project; on shutdown, release
    the LLM/embedding HTTP clients and the vector store connection."""
    db.init_global()
    logger.info("Engimind backend started")
    # Try connecting to Qdrant if configured
    with db.global_session() as s:
        vdb = s.query(VectorDBConfig).first()
        if vdb and vdb.endpoint:
            try:
                vector_store.connect(vdb.endpoint)
            except Exception as e:
                # Startup must not fail when the vector DB is down; search
                # degrades to keyword-only until reconnected.
                logger.warning("Qdrant connection failed at startup: %s", e)
    # Auto-open the most recent project so API calls work immediately
    with db.global_session() as s:
        latest = s.query(Project).order_by(Project.updated_at.desc()).first()
        if latest and latest.path:
            try:
                db.open_project(latest.id, latest.path)
                logger.info("Auto-opened project: %s (%s)", latest.name, latest.id)
            except Exception as e:
                logger.warning("Failed to auto-open project: %s", e)
    yield
    # Shutdown: close async clients before dropping the vector connection.
    await llm_client.close()
    await embedding_service.close()
    vector_store.disconnect()
    logger.info("Engimind backend stopped")
|
|
|
|
|
|
app = FastAPI(title="Engimind", version="2.0.0", lifespan=lifespan)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# not honored by browsers (the CORS spec forbids credentialed wildcard
# origins). Acceptable for a loopback-only Electron shell — confirm before
# ever exposing this server beyond 127.0.0.1.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Health
|
|
# ══════════════════════════════════════════════
|
|
|
|
@app.get("/health")
|
|
async def health():
|
|
return {"status": "ok"}
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Config: LLM Providers
|
|
# ══════════════════════════════════════════════
|
|
|
|
@app.get("/api/providers", response_model=list[LLMProviderSchema])
|
|
async def get_all_providers():
|
|
with db.global_session() as s:
|
|
rows = s.query(LLMProvider).all()
|
|
return [LLMProviderSchema.from_orm_obj(r) for r in rows]
|
|
|
|
|
|
@app.post("/api/providers")
|
|
async def save_provider(p: LLMProviderSchema):
|
|
with db.global_session() as s:
|
|
obj = s.query(LLMProvider).filter_by(id=p.id).first()
|
|
if obj:
|
|
obj.name, obj.provider, obj.base_url = p.name, p.provider, p.url
|
|
obj.api_key, obj.model_id, obj.enabled = p.key, p.model, p.enabled
|
|
else:
|
|
obj = LLMProvider(id=p.id, name=p.name, provider=p.provider,
|
|
base_url=p.url, api_key=p.key, model_id=p.model,
|
|
enabled=p.enabled)
|
|
s.add(obj)
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
@app.delete("/api/providers/{provider_id}")
|
|
async def delete_provider(provider_id: str):
|
|
with db.global_session() as s:
|
|
s.query(LLMProvider).filter_by(id=provider_id).delete()
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Config: VectorDB
|
|
# ══════════════════════════════════════════════
|
|
|
|
@app.get("/api/vector-db", response_model=VectorDBConfigSchema)
|
|
async def get_vector_db():
|
|
with db.global_session() as s:
|
|
cfg = s.query(VectorDBConfig).first()
|
|
if not cfg:
|
|
return VectorDBConfigSchema()
|
|
return VectorDBConfigSchema(endpoint=cfg.endpoint, apiKey=cfg.api_key, status=cfg.status)
|
|
|
|
|
|
@app.post("/api/vector-db")
|
|
async def save_vector_db(c: VectorDBConfigSchema):
|
|
with db.global_session() as s:
|
|
cfg = s.query(VectorDBConfig).first()
|
|
if cfg:
|
|
cfg.endpoint, cfg.api_key, cfg.status = c.endpoint, c.apiKey, c.status
|
|
else:
|
|
s.add(VectorDBConfig(id=1, endpoint=c.endpoint, api_key=c.apiKey, status=c.status))
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
@app.post("/api/vector-db/test")
|
|
async def test_vector_db(body: dict):
|
|
endpoint = body.get("endpoint", "")
|
|
try:
|
|
from qdrant_client import QdrantClient
|
|
client = QdrantClient(url=endpoint, timeout=5)
|
|
client.get_collections()
|
|
client.close()
|
|
return {"ok": True}
|
|
except Exception as e:
|
|
raise HTTPException(400, f"连接失败: {e}")
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Config: Embedding
|
|
# ══════════════════════════════════════════════
|
|
|
|
@app.get("/api/embedding", response_model=EmbeddingConfigSchema)
|
|
async def get_embedding_config():
|
|
with db.global_session() as s:
|
|
cfg = s.query(EmbeddingModelConfig).first()
|
|
if not cfg:
|
|
return EmbeddingConfigSchema()
|
|
return EmbeddingConfigSchema.from_orm_obj(cfg)
|
|
|
|
|
|
@app.post("/api/embedding")
|
|
async def save_embedding_config(c: EmbeddingConfigSchema):
|
|
with db.global_session() as s:
|
|
cfg = s.query(EmbeddingModelConfig).first()
|
|
if cfg:
|
|
cfg.provider, cfg.base_url, cfg.api_key = c.provider, c.url, c.key
|
|
cfg.model_id, cfg.enabled = c.model, c.enabled
|
|
cfg.dimensions = c.dimensions
|
|
else:
|
|
s.add(EmbeddingModelConfig(id=1, provider=c.provider, base_url=c.url,
|
|
api_key=c.key, model_id=c.model, enabled=c.enabled,
|
|
dimensions=c.dimensions))
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
@app.post("/api/embedding/test")
|
|
async def test_embedding(body: dict):
|
|
try:
|
|
vec = await embedding_service.get_embedding(
|
|
"hello", body["url"], body["model"], body.get("key", ""), body["provider"])
|
|
if not vec:
|
|
raise HTTPException(400, "返回的向量为空")
|
|
return {"ok": True, "dim": len(vec)}
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
raise HTTPException(400, f"Embedding 测试失败: {e}")
|
|
|
|
|
|
@app.post("/api/llm/test")
|
|
async def test_llm(body: dict):
|
|
provider = body.get("provider", "")
|
|
base_url = body.get("url", "").rstrip("/")
|
|
api_key = body.get("key", "")
|
|
|
|
import httpx
|
|
if provider == "Ollama":
|
|
url = base_url + "/api/tags"
|
|
elif provider in ("DeepSeek", "OpenAI"):
|
|
url = base_url + "/models"
|
|
else:
|
|
url = base_url + "/v1/models"
|
|
|
|
headers = {}
|
|
if api_key and provider != "Ollama":
|
|
headers["Authorization"] = f"Bearer {api_key}"
|
|
|
|
try:
|
|
async with httpx.AsyncClient(timeout=8) as client:
|
|
resp = await client.get(url, headers=headers)
|
|
if 200 <= resp.status_code < 400:
|
|
return {"ok": True}
|
|
if resp.status_code == 401:
|
|
raise HTTPException(400, "API Key 无效")
|
|
raise HTTPException(400, f"HTTP {resp.status_code}")
|
|
except httpx.HTTPError as e:
|
|
raise HTTPException(400, f"网络连通性异常: {e}")
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Projects
|
|
# ══════════════════════════════════════════════
|
|
|
|
@app.get("/api/projects", response_model=list[ProjectSchema])
|
|
async def list_projects():
|
|
with db.global_session() as s:
|
|
return [ProjectSchema.model_validate(p) for p in
|
|
s.query(Project).order_by(Project.created_at.desc()).all()]
|
|
|
|
|
|
@app.post("/api/projects", response_model=ProjectSchema)
|
|
async def create_project(body: dict):
|
|
name = body.get("name", "未命名项目")
|
|
pid = f"p-{int(time.time() * 1000)}"
|
|
proj_dir = db.projects_base_dir() / pid
|
|
proj_dir.mkdir(parents=True, exist_ok=True)
|
|
db_path = str(proj_dir / "project.db")
|
|
|
|
with db.global_session() as s:
|
|
proj = Project(id=pid, name=name, path=db_path)
|
|
s.add(proj)
|
|
s.commit()
|
|
|
|
db.open_project(pid, db_path)
|
|
return ProjectSchema(id=pid, name=name)
|
|
|
|
|
|
@app.post("/api/projects/{project_id}/switch")
|
|
async def switch_project(project_id: str):
|
|
with db.global_session() as s:
|
|
proj = s.query(Project).filter_by(id=project_id).first()
|
|
if not proj:
|
|
raise HTTPException(404, "项目不存在")
|
|
db.open_project(proj.id, proj.path)
|
|
return {"ok": True}
|
|
|
|
|
|
@app.delete("/api/projects/{project_id}")
|
|
async def delete_project(project_id: str):
|
|
if db.current_project_id == project_id:
|
|
db.close_project()
|
|
with db.global_session() as s:
|
|
proj = s.query(Project).filter_by(id=project_id).first()
|
|
if proj:
|
|
proj_dir = os.path.dirname(proj.path)
|
|
if os.path.isdir(proj_dir):
|
|
shutil.rmtree(proj_dir, ignore_errors=True)
|
|
s.delete(proj)
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Materials
|
|
# ══════════════════════════════════════════════
|
|
|
|
@app.get("/api/materials", response_model=list[SourceFileSchema])
|
|
async def get_project_files():
|
|
session = db.project_session()
|
|
if not session or not db.current_project_id:
|
|
return []
|
|
with session as s:
|
|
files = s.query(SourceFile).filter_by(project_id=db.current_project_id)\
|
|
.order_by(SourceFile.created_at.desc()).all()
|
|
return [SourceFileSchema.from_orm_obj(f) for f in files]
|
|
|
|
|
|
@app.get("/api/materials/{file_id}/content")
|
|
async def get_material_content(file_id: str):
|
|
"""Return file content for online preview.
|
|
|
|
- Excel: structured JSON with sheets -> rows (for table rendering)
|
|
- Others: markdown text
|
|
"""
|
|
session = db.project_session()
|
|
if not session:
|
|
raise HTTPException(400, "项目数据库未就绪")
|
|
with session as s:
|
|
sf = s.query(SourceFile).filter_by(id=file_id).first()
|
|
if not sf:
|
|
raise HTTPException(404, "素材未找到")
|
|
file_path = sf.file_path
|
|
file_type = sf.type
|
|
cached_content = sf.parsed_content or ""
|
|
|
|
if not file_path or not os.path.isfile(file_path):
|
|
raise HTTPException(400, "文件不存在")
|
|
|
|
if file_type == "excel":
|
|
# Return structured table data for rich rendering
|
|
from parsers.excel_parser import _iter_sheets, _cell_str
|
|
sheets = []
|
|
for sheet_name, grid, _ in _iter_sheets(file_path):
|
|
if not grid:
|
|
continue
|
|
rows = []
|
|
for row in grid:
|
|
rows.append([_cell_str(c) for c in row])
|
|
sheets.append({"name": sheet_name, "rows": rows})
|
|
return {"type": "excel", "sheets": sheets}
|
|
else:
|
|
# For non-excel: return cached markdown or re-parse
|
|
if cached_content:
|
|
return {"type": file_type, "content": cached_content}
|
|
result = parse_file(file_path)
|
|
return {"type": file_type, "content": result.get("markdown", "")}
|
|
|
|
|
|
class UploadRequest(BaseModel):
    # Absolute local paths chosen in the Electron file dialog.
    filePaths: list[str]
|
|
|
|
|
|
@app.post("/api/materials/upload", response_model=list[SourceFileSchema])
|
|
async def upload_materials(req: UploadRequest):
|
|
"""Upload materials by local file paths (Electron sends paths from dialog)."""
|
|
pid = db.current_project_id
|
|
if not pid:
|
|
raise HTTPException(400, "请先创建或选择一个项目")
|
|
session = db.project_session()
|
|
if not session:
|
|
raise HTTPException(400, "项目数据库未就绪")
|
|
|
|
files_dir = db.projects_base_dir() / pid / "files"
|
|
files_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
results = []
|
|
for i, src_path in enumerate(req.filePaths):
|
|
if not os.path.isfile(src_path):
|
|
continue
|
|
file_name = os.path.basename(src_path)
|
|
file_id = f"f-{int(time.time() * 1000)}-{i}"
|
|
file_type = detect_file_type(src_path)
|
|
file_size = get_file_size(src_path)
|
|
|
|
dst_path = str(files_dir / f"{file_id}_{file_name}")
|
|
shutil.copy2(src_path, dst_path)
|
|
|
|
with session as s:
|
|
# Remove duplicates
|
|
existing = s.query(SourceFile).filter_by(project_id=pid, name=file_name).all()
|
|
for old in existing:
|
|
if old.file_path and os.path.isfile(old.file_path):
|
|
os.remove(old.file_path)
|
|
s.query(TextChunk).filter_by(source_id=old.id).delete()
|
|
s.delete(old)
|
|
if vector_store.connected:
|
|
vector_store.delete_by_source(pid, old.id)
|
|
|
|
sf = SourceFile(id=file_id, project_id=pid, name=file_name,
|
|
type=file_type, category=categorize_file(file_type),
|
|
file_path=dst_path, size=file_size, vector_status="pending")
|
|
s.add(sf)
|
|
s.commit()
|
|
|
|
results.append(SourceFileSchema(id=file_id, name=file_name, type=file_type,
|
|
category=categorize_file(file_type),
|
|
size=file_size, vectorStatus="pending"))
|
|
|
|
# Parse and index in background
|
|
asyncio.create_task(_parse_and_index(pid, file_id, file_name, file_type, dst_path))
|
|
|
|
return results
|
|
|
|
|
|
@app.delete("/api/materials/{file_id}")
|
|
async def delete_material(file_id: str):
|
|
session = db.project_session()
|
|
if not session:
|
|
raise HTTPException(400, "项目数据库未就绪")
|
|
pid = db.current_project_id
|
|
with session as s:
|
|
sf = s.query(SourceFile).filter_by(id=file_id).first()
|
|
if not sf:
|
|
raise HTTPException(404, "素材未找到")
|
|
if sf.file_path and os.path.isfile(sf.file_path):
|
|
os.remove(sf.file_path)
|
|
s.query(TextChunk).filter_by(source_id=file_id).delete()
|
|
s.delete(sf)
|
|
s.commit()
|
|
if vector_store.connected and pid:
|
|
vector_store.delete_by_source(pid, file_id)
|
|
return {"ok": True}
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Chat
|
|
# ══════════════════════════════════════════════
|
|
|
|
@app.get("/api/chat/messages", response_model=list[ChatMessageSchema])
|
|
async def get_chat_messages():
|
|
session = db.project_session()
|
|
if not session or not db.current_project_id:
|
|
return []
|
|
with session as s:
|
|
msgs = s.query(ChatMessage).filter_by(project_id=db.current_project_id)\
|
|
.order_by(ChatMessage.created_at).all()
|
|
return [ChatMessageSchema.from_orm_obj(m) for m in msgs]
|
|
|
|
|
|
class SaveMessageRequest(BaseModel):
    role: str  # chat role string, stored verbatim
    content: str  # message body
    sources: str = ""  # serialized source references (opaque to this API)
    citations: str = ""  # serialized citation markers (opaque to this API)
|
|
|
|
|
|
@app.post("/api/chat/messages")
|
|
async def save_chat_message(req: SaveMessageRequest):
|
|
session = db.project_session()
|
|
if not session or not db.current_project_id:
|
|
raise HTTPException(400, "无活动项目")
|
|
with session as s:
|
|
s.add(ChatMessage(project_id=db.current_project_id, role=req.role,
|
|
content=req.content, sources=req.sources, citations=req.citations))
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
@app.delete("/api/chat/messages")
|
|
async def clear_chat_messages():
|
|
session = db.project_session()
|
|
if not session or not db.current_project_id:
|
|
return {"ok": True}
|
|
with session as s:
|
|
s.query(ChatMessage).filter_by(project_id=db.current_project_id).delete()
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
class SendMessageRequest(BaseModel):
    content: str  # user question / chat turn
    # Restrict retrieval to these material ids; empty means all materials.
    # (Mutable default is safe here: pydantic copies field defaults per model.)
    selectedFileIds: list[str] = []
    modelId: str  # id of the LLMProvider row to use (must be enabled)
|
|
|
|
|
|
# System prompt for RAG chat: constrains the model to answer strictly from
# the supplied Context material (no fabrication, no web search), to match
# attribute paths regardless of word order, and to cite sources as [N].
_CHAT_SYSTEM_PROMPT = (
    "你是一位专业的工程技术助手。你只能从提供的 Context 素材中提取数值和信息来回答问题。\n"
    "重要约束:\n"
    "1. 如果 Context 中没有与问题相关的数据,必须诚实回答『在当前素材中未找到相关数据』,严禁编造数值或结论。\n"
    "2. 绝对不要联网搜索或自行猜测数据。\n"
    "3. 素材中的数据路径可能与用户提问的词序不同。例如素材写『湿地 > 国家所有』而用户问『国家所有湿地』,"
    "你必须识别出这是同一个属性并正确引用数值。\n"
    "4. 如果素材信息不完整,应指出缺失部分,而非自行填补。\n"
    "5. 引用来源时使用 [N] 标注。"
)
|
|
|
|
|
|
@app.post("/api/chat/send")
|
|
async def send_message(req: SendMessageRequest):
|
|
"""Non-streaming chat with RAG context."""
|
|
provider = _get_provider(req.modelId)
|
|
context_text = _search_material_context(req.content, req.selectedFileIds, 5)
|
|
|
|
messages = [
|
|
{"role": "system", "content": _CHAT_SYSTEM_PROMPT},
|
|
]
|
|
if context_text:
|
|
messages.append({"role": "user", "content": f"参考以下工程素材:\n\n{context_text}\n\n---\n\n用户提问:{req.content}"})
|
|
else:
|
|
messages.append({"role": "user", "content": req.content})
|
|
|
|
result = await llm_client.complete(provider["url"], provider["key"], provider["model"], messages)
|
|
if result.get("choices"):
|
|
return {"content": result["choices"][0]["message"]["content"]}
|
|
raise HTTPException(500, "模型无响应")
|
|
|
|
|
|
@app.post("/api/chat/stream")
|
|
async def stream_message(req: SendMessageRequest):
|
|
"""Streaming chat with RAG context via SSE."""
|
|
provider = _get_provider(req.modelId)
|
|
context_text = _search_material_context(req.content, req.selectedFileIds, 5)
|
|
|
|
messages = [
|
|
{"role": "system", "content": _CHAT_SYSTEM_PROMPT},
|
|
]
|
|
if context_text:
|
|
messages.append({"role": "user", "content": f"参考以下工程素材:\n\n{context_text}\n\n---\n\n用户提问:{req.content}"})
|
|
else:
|
|
messages.append({"role": "user", "content": req.content})
|
|
|
|
async def event_gen():
|
|
import json
|
|
async for chunk in llm_client.stream_complete(provider["url"], provider["key"], provider["model"], messages):
|
|
yield {"data": json.dumps(chunk, ensure_ascii=False)}
|
|
|
|
return EventSourceResponse(event_gen())
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Template / Delivery Standard
|
|
# ══════════════════════════════════════════════
|
|
|
|
@app.get("/api/template/chapters", response_model=list[TemplateChapterSchema])
|
|
async def get_template_chapters():
|
|
session = db.project_session()
|
|
if not session or not db.current_project_id:
|
|
return []
|
|
with session as s:
|
|
chapters = s.query(TemplateChapter).filter_by(project_id=db.current_project_id)\
|
|
.order_by(TemplateChapter.sort_order).all()
|
|
return [TemplateChapterSchema.from_orm_obj(c) for c in chapters]
|
|
|
|
|
|
@app.post("/api/template/chapters")
|
|
async def save_template_chapters(chapters: list[TemplateChapterSchema]):
|
|
session = db.project_session()
|
|
if not session or not db.current_project_id:
|
|
raise HTTPException(400, "无活动项目")
|
|
with session as s:
|
|
s.query(TemplateChapter).filter_by(project_id=db.current_project_id).delete()
|
|
for i, c in enumerate(chapters):
|
|
s.add(TemplateChapter(
|
|
id=c.id, project_id=db.current_project_id,
|
|
title=c.title, status=c.status, progress=c.progress,
|
|
content=c.content, sort_order=i,
|
|
))
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
@app.delete("/api/template/chapters/{chapter_id}")
|
|
async def delete_template_chapter(chapter_id: str):
|
|
session = db.project_session()
|
|
if not session:
|
|
raise HTTPException(400, "无活动项目")
|
|
with session as s:
|
|
s.query(TemplateChapter).filter_by(id=chapter_id).delete()
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
@app.post("/api/template/extract-directory")
|
|
async def stream_template_directory(body: dict):
|
|
"""Stream LLM directory extraction from delivery standard text."""
|
|
content = body.get("content", "")
|
|
model_id = body.get("modelId", "")
|
|
provider = _get_provider(model_id)
|
|
|
|
prompt = (
|
|
"你是一个工程标准的目录解析助手。请从下面提供的交付标准文本中提取完整的章节目录结构(包括主章节和小节),并以 JSON 数组的格式返回。\n\n"
|
|
"### 要求:\n"
|
|
"1. 只返回 JSON 数组,不包含其他废话或者回答前缀。\n"
|
|
"2. 必须提取所有层级:主章节和其下的小节(子章节)。\n"
|
|
"3. 输出格式必须严格符合:\n"
|
|
'[{"id": "chapter-1", "title": "1 原材料进场检验", "content": "简要描述(可选)", '
|
|
'"children": [{"id": "chapter-1-1", "title": "1.1 钢材检验", "content": ""}]}]\n'
|
|
"4. 如果某个章节没有小节,children 可以为空数组 []。\n\n"
|
|
f"### 交付标准内容:\n{content}"
|
|
)
|
|
messages = [
|
|
{"role": "system", "content": "你是一个专业的结构化数据抽取工具。你只输出合法的 JSON,不要使用 Markdown 代码块包裹,也不要给出任何其他解释。"},
|
|
{"role": "user", "content": prompt},
|
|
]
|
|
|
|
async def event_gen():
|
|
import json
|
|
async for chunk in llm_client.stream_complete(provider["url"], provider["key"], provider["model"], messages):
|
|
yield {"data": json.dumps(chunk, ensure_ascii=False)}
|
|
|
|
return EventSourceResponse(event_gen())
|
|
|
|
|
|
@app.get("/api/delivery-standard", response_model=Optional[DeliveryStandardSchema])
|
|
async def get_delivery_standard():
|
|
session = db.project_session()
|
|
if not session or not db.current_project_id:
|
|
return None
|
|
with session as s:
|
|
ds = s.query(DeliveryStandard).filter_by(project_id=db.current_project_id).first()
|
|
if not ds:
|
|
return None
|
|
return DeliveryStandardSchema(fileName=ds.file_name, content=ds.content)
|
|
|
|
|
|
@app.post("/api/delivery-standard")
|
|
async def save_delivery_standard(body: DeliveryStandardSchema):
|
|
session = db.project_session()
|
|
if not session or not db.current_project_id:
|
|
raise HTTPException(400, "无活动项目")
|
|
with session as s:
|
|
s.query(DeliveryStandard).filter_by(project_id=db.current_project_id).delete()
|
|
s.add(DeliveryStandard(project_id=db.current_project_id,
|
|
file_name=body.fileName, content=body.content))
|
|
s.commit()
|
|
return {"ok": True}
|
|
|
|
|
|
@app.post("/api/parse-file")
|
|
async def parse_file_endpoint(body: dict):
|
|
"""Parse a local file and return markdown. Used for delivery standard parsing."""
|
|
file_path = body.get("filePath", "")
|
|
if not file_path or not os.path.isfile(file_path):
|
|
raise HTTPException(400, "文件不存在")
|
|
result = parse_file(file_path)
|
|
return result
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Excel Pre-parse & Final Ingest
|
|
# ══════════════════════════════════════════════
|
|
|
|
class PreParseRequest(BaseModel):
    # Id of the SourceFile row to pre-parse (must be an Excel material).
    fileId: str
    startRow: Optional[int] = None  # optional override for "refresh preview"
|
|
|
|
|
|
@app.post("/api/materials/pre-parse")
|
|
async def pre_parse_material(req: PreParseRequest):
|
|
"""Interface A: Pre-parse an Excel file and return preview data."""
|
|
session = db.project_session()
|
|
if not session:
|
|
raise HTTPException(400, "项目数据库未就绪")
|
|
with session as s:
|
|
sf = s.query(SourceFile).filter_by(id=req.fileId).first()
|
|
if not sf:
|
|
raise HTTPException(404, "素材未找到")
|
|
file_path = sf.file_path
|
|
file_type = sf.type
|
|
|
|
if file_type != "excel":
|
|
raise HTTPException(400, "仅支持 Excel 文件的预解析")
|
|
if not file_path or not os.path.isfile(file_path):
|
|
raise HTTPException(400, "文件不存在")
|
|
|
|
result = pre_parse_excel(file_path, start_row=req.startRow)
|
|
return result
|
|
|
|
|
|
class FinalIngestRequest(BaseModel):
    # Id of the SourceFile row to ingest (must be an Excel material).
    fileId: str
    # Confirmed first data row from the pre-parse preview.
    startRow: int
|
|
|
|
|
|
@app.post("/api/materials/ingest")
|
|
async def final_ingest_material(req: FinalIngestRequest):
|
|
"""Interface B: Final ingest — parse with confirmed start_row and vector index."""
|
|
pid = db.current_project_id
|
|
if not pid:
|
|
raise HTTPException(400, "请先选择一个项目")
|
|
session = db.project_session()
|
|
if not session:
|
|
raise HTTPException(400, "项目数据库未就绪")
|
|
|
|
with session as s:
|
|
sf = s.query(SourceFile).filter_by(id=req.fileId).first()
|
|
if not sf:
|
|
raise HTTPException(404, "素材未找到")
|
|
file_path = sf.file_path
|
|
file_name = sf.name
|
|
file_type = sf.type
|
|
|
|
if file_type != "excel":
|
|
raise HTTPException(400, "仅支持 Excel 文件")
|
|
if not file_path or not os.path.isfile(file_path):
|
|
raise HTTPException(400, "文件不存在")
|
|
|
|
# Launch background ingest task
|
|
asyncio.create_task(_ingest_excel(pid, req.fileId, file_name, file_path, req.startRow))
|
|
return {"ok": True, "message": f"开始入库,数据起始行: {req.startRow}"}
|
|
|
|
|
|
async def _ingest_excel(project_id: str, file_id: str, file_name: str,
                        file_path: str, start_row: int):
    """Background: parse Excel with confirmed start_row, clear old vectors, re-index.

    Progress is pushed to the UI through _broadcast_event as
    "material_status_update" events, and SourceFile.vector_status is kept
    in sync (processing -> done | error).
    """
    import json

    session = db.project_session()
    if not session:
        return  # project was closed in the meantime; nothing to do

    def emit(status: str, step: str):
        # Push one status event to all connected SSE clients.
        _broadcast_event(json.dumps({
            "type": "material_status_update",
            "fileId": file_id, "fileName": file_name,
            "status": status, "step": step,
        }, ensure_ascii=False))

    try:
        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.vector_status = "processing"
                s.commit()

        emit("processing", f"📄 开始重新解析「{file_name}」(起始行: {start_row})")

        # Step 1: Parse with confirmed start_row
        raw_chunks = parse_excel_to_chunks(file_path, start_row=start_row)
        text_chunks = [c['content'] for c in raw_chunks]
        metadata_list = [c['metadata'] for c in raw_chunks]
        emit("processing", f"🔪 文本分块: 共 {len(text_chunks)} 个片段")

        # Step 2: Clear old data (SQLite chunks first, then vectors)
        with session as s:
            s.query(TextChunk).filter_by(source_id=file_id).delete()
            s.commit()
        if vector_store.connected:
            vector_store.delete_by_source(project_id, file_id)
        emit("processing", "🗑️ 已清除旧向量数据")

        # Step 3: Store chunks in SQLite (keyword-search backing store)
        with session as s:
            for i, chunk in enumerate(text_chunks):
                s.add(TextChunk(id=f"{file_id}-chunk-{i}", project_id=project_id,
                                source_id=file_id, content=chunk, chunk_idx=i))
            s.commit()

        # Step 4: Vector index (only when embedding config + store available)
        emb_cfg = _get_embedding_config()
        if emb_cfg and vector_store.connected:
            emit("processing", f"🔗 开始向量化索引 ({len(text_chunks)} 分块)...")
            dim = emb_cfg.get("dimensions", 1024)

            def on_progress(done: int, total: int):
                emit("processing", f"🔗 向量化进度: {done}/{total}")

            await rag_service.index_chunks(
                project_id, file_id, text_chunks, emb_cfg,
                dim=dim, metadata_list=metadata_list,
                on_progress=on_progress,
            )
            emit("done", "✓ 素材完全就绪,向量索引成功")
        else:
            # Without vectors the chunks still support keyword search,
            # so the material is considered ready.
            emit("done", f"✓ 素材就绪 ({len(text_chunks)} 分块)")

        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.vector_status = "done"
                s.commit()

    except Exception as e:
        logger.exception("Excel ingest failed for %s", file_name)
        emit("error", f"入库失败: {e}")
        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.vector_status = "error"
                s.commit()
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Chapter Generation
|
|
# ══════════════════════════════════════════════
|
|
|
|
class GenerateChapterRequest(BaseModel):
    # Chapter heading/requirements text: used both as the retrieval query
    # and as the writing instruction in the prompt.
    chapterTitle: str
    # Restrict retrieval to these material ids; empty means all materials.
    selectedFileIds: list[str] = []
    modelId: str  # id of the LLMProvider row to use (must be enabled)
|
|
|
|
|
|
@app.post("/api/chat/generate-chapter")
|
|
async def stream_generate_chapter(req: GenerateChapterRequest):
|
|
provider = _get_provider(req.modelId)
|
|
context_text = _search_material_context(req.chapterTitle, req.selectedFileIds, 8)
|
|
|
|
prompt = (
|
|
f"你是工程报告撰写专家。请根据以下工程素材,按照章节要求撰写报告内容。\n\n"
|
|
f"## 章节要求\n{req.chapterTitle}\n\n"
|
|
f"## 参考素材\n{context_text}\n\n"
|
|
f"## 输出要求\n"
|
|
f"1. 使用 Markdown 格式\n"
|
|
f"2. 引用素材时使用 [N] 标注\n"
|
|
f"3. 内容专业、结构清晰\n"
|
|
f"4. 包含具体数据和分析结论"
|
|
)
|
|
messages = [
|
|
{"role": "system", "content": (
|
|
"你是一位资深工程报告撰写专家,擅长根据工程素材生成结构化的技术报告章节。\n"
|
|
"重要约束:只能使用参考素材中的数据和信息。如果素材不足以支撑某个章节,请明确标注"
|
|
"『素材中未包含此部分数据,需补充』,绝不可编造数值或结论。"
|
|
)},
|
|
{"role": "user", "content": prompt},
|
|
]
|
|
|
|
async def event_gen():
|
|
import json
|
|
async for chunk in llm_client.stream_complete(provider["url"], provider["key"], provider["model"], messages):
|
|
yield {"data": json.dumps(chunk, ensure_ascii=False)}
|
|
|
|
return EventSourceResponse(event_gen())
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# SSE Event Stream (material status broadcast)
|
|
# ══════════════════════════════════════════════
|
|
|
|
# Simple in-memory event bus for material processing status.
# Each connected SSE client owns one Queue; producers fan out to all queues.
_event_subscribers: list[asyncio.Queue] = []
|
|
|
|
|
|
@app.get("/api/events")
|
|
async def event_stream():
|
|
"""SSE endpoint for real-time events (material status, etc.)."""
|
|
queue: asyncio.Queue = asyncio.Queue()
|
|
_event_subscribers.append(queue)
|
|
|
|
async def gen():
|
|
try:
|
|
while True:
|
|
event = await queue.get()
|
|
yield {"data": event}
|
|
except asyncio.CancelledError:
|
|
pass
|
|
finally:
|
|
_event_subscribers.remove(queue)
|
|
|
|
return EventSourceResponse(gen())
|
|
|
|
|
|
def _broadcast_event(event_data: str):
    """Fan an event string out to every connected SSE subscriber.

    Fix: iterate over a snapshot of the subscriber list — event_stream's
    ``finally`` removes queues from _event_subscribers, and mutating the
    list mid-iteration could make this loop skip a live subscriber.
    """
    for q in list(_event_subscribers):
        q.put_nowait(event_data)
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Internal helpers
|
|
# ══════════════════════════════════════════════
|
|
|
|
def _get_provider(model_id: str) -> dict:
    """Look up an enabled LLM provider by id; 400 when missing or disabled."""
    with db.global_session() as s:
        row = s.query(LLMProvider).filter_by(id=model_id, enabled=True).first()
        if row is None:
            raise HTTPException(400, f"模型 {model_id} 未找到或已禁用")
        return {"url": row.base_url, "key": row.api_key,
                "model": row.model_id, "provider": row.provider}
|
|
|
|
|
|
def _get_embedding_config() -> dict | None:
    """Embedding settings as a plain dict, or None when disabled/unset."""
    with db.global_session() as s:
        cfg = s.query(EmbeddingModelConfig).first()
        if not (cfg and cfg.enabled and cfg.base_url):
            return None
        return {"base_url": cfg.base_url, "model": cfg.model_id,
                "api_key": cfg.api_key, "provider": cfg.provider,
                "dimensions": cfg.dimensions or 1024}
|
|
|
|
|
|
def _search_material_context(query: str, file_ids: list[str], top_k: int) -> str:
    """Hybrid search: vector + keyword in parallel, merge and deduplicate.

    Vector search captures semantic similarity; keyword search captures
    exact term matches regardless of word order — together they fix the
    issue where '水域集体所有' hits but '集体所有水域' misses.

    Returns the merged chunk texts joined by "---" separators, or "" when
    no project is active or nothing matches.
    """
    pid = db.current_project_id
    if not pid:
        return ""

    vector_chunks = []
    keyword_chunks = []

    # ── Vector search ──
    emb_cfg = _get_embedding_config()
    if emb_cfg and vector_store.connected:
        try:
            import asyncio
            import concurrent.futures
            chunks_coro = rag_service.search_context(
                pid, query, top_k, emb_cfg, file_ids or None,
            )
            # NOTE(review): this sync helper is called from async endpoints.
            # asyncio.run on a worker thread lets the coroutine be awaited
            # here, but future.result() still blocks the caller's event loop
            # for the duration of the search — consider making this async.
            with concurrent.futures.ThreadPoolExecutor() as pool:
                future = pool.submit(asyncio.run, chunks_coro)
                vector_chunks = future.result()
        except Exception as e:
            # Degrade gracefully to keyword-only results.
            logger.warning("Vector search failed: %s", e)

    # ── Keyword search (always run in parallel) ──
    session = db.project_session()
    if session:
        with session as s:
            keyword_chunks = search_text_chunks_keyword(
                s, pid, query, file_ids or None, top_k,
            )

    # ── Merge and deduplicate ──
    merged = _merge_search_results(vector_chunks, keyword_chunks, top_k)

    if not merged:
        return ""
    return "\n\n---\n\n".join(c.get("text", "") for c in merged)
|
|
|
|
|
|
def _merge_search_results(vector_chunks: list[dict],
|
|
keyword_chunks: list[dict],
|
|
top_k: int) -> list[dict]:
|
|
"""Merge vector and keyword results, deduplicate by text content.
|
|
|
|
Priority: vector results first (semantically ranked), then keyword
|
|
results that weren't already found by vector search.
|
|
"""
|
|
seen_texts: set = set()
|
|
merged: list[dict] = []
|
|
|
|
def _text_key(text: str) -> str:
|
|
"""Normalize text for dedup: strip whitespace, take first 80 chars."""
|
|
return text.strip()[:80] if text else ""
|
|
|
|
# Vector results first (higher priority)
|
|
for c in vector_chunks:
|
|
key = _text_key(c.get("text", ""))
|
|
if key and key not in seen_texts:
|
|
seen_texts.add(key)
|
|
merged.append(c)
|
|
|
|
# Keyword results fill remaining slots
|
|
for c in keyword_chunks:
|
|
if len(merged) >= top_k:
|
|
break
|
|
key = _text_key(c.get("text", ""))
|
|
if key and key not in seen_texts:
|
|
seen_texts.add(key)
|
|
merged.append(c)
|
|
|
|
return merged[:top_k]
|
|
|
|
|
|
async def _parse_and_index(project_id: str, file_id: str, file_name: str,
                           file_type: str, file_path: str):
    """Background task: parse file, store chunks, optionally vector index.

    Pipeline: mark processing -> parse to markdown -> cache parsed text ->
    chunk (Excel gets structured chunks, others fixed-size text windows) ->
    store chunks in SQLite -> vector-index when configured -> mark done.
    Progress goes out over _broadcast_event; failures set status "error".
    """
    import json

    session = db.project_session()
    if not session:
        return  # project was closed in the meantime; nothing to do

    def emit(status: str, step: str):
        # Push one status event to all connected SSE clients.
        _broadcast_event(json.dumps({
            "type": "material_status_update",
            "fileId": file_id, "fileName": file_name,
            "status": status, "step": step,
        }, ensure_ascii=False))

    try:
        # Update status
        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.vector_status = "processing"
                s.commit()

        emit("processing", f"📄 开始处理「{file_name}」")
        emit("processing", f"📖 正在解析 {file_type.upper()} 文件内容...")

        result = parse_file(file_path)
        content = result.get("markdown", "")
        if not content:
            # Empty parse result is treated as a hard failure.
            emit("error", f"解析失败: {result.get('error', '无内容')}")
            with session as s:
                sf = s.query(SourceFile).filter_by(id=file_id).first()
                if sf:
                    sf.vector_status = "error"
                    s.commit()
            return

        emit("processing", f"✓ 文件解析完成 (提取 {len(content)} 字符)")

        # Save parsed content (serves as the preview cache)
        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.parsed_content = content
                s.commit()

        # Chunk text
        if file_type == "excel":
            # Structured row-wise chunks with per-chunk metadata.
            raw_chunks = parse_excel_to_chunks(file_path, start_row=None)
            text_chunks = [c['content'] for c in raw_chunks]
            metadata_list = [c['metadata'] for c in raw_chunks]
        else:
            # Fixed-size windows: 800 chars with 80-char overlap.
            text_chunks = chunk_text(content, 800, 80)
            metadata_list = None

        emit("processing", f"🔪 文本分块: 共 {len(text_chunks)} 个片段")

        # Store chunks in SQLite (keyword-search backing store)
        with session as s:
            for i, chunk in enumerate(text_chunks):
                s.add(TextChunk(id=f"{file_id}-chunk-{i}", project_id=project_id,
                                source_id=file_id, content=chunk, chunk_idx=i))
            s.commit()

        # Optional: vector indexing (use pre-chunked text instead of re-chunking)
        emb_cfg = _get_embedding_config()
        # api_key is redacted from the log line on purpose.
        logger.info("Vector indexing check: emb_cfg=%s, connected=%s",
                    {k: v for k, v in emb_cfg.items() if k != 'api_key'} if emb_cfg else None,
                    vector_store.connected)
        if emb_cfg and vector_store.connected:
            emit("processing", f"🔗 开始向量化索引 ({len(text_chunks)} 分块)...")
            try:
                dim = emb_cfg.get("dimensions", 1024)
                logger.info("Calling index_chunks: project=%s, source=%s, chunks=%d, dim=%d",
                            project_id, file_id, len(text_chunks), dim)

                def on_progress(done: int, total: int):
                    emit("processing", f"🔗 向量化进度: {done}/{total}")

                await rag_service.index_chunks(
                    project_id, file_id, text_chunks, emb_cfg,
                    dim=dim, metadata_list=metadata_list,
                    on_progress=on_progress,
                )
                emit("done", "✓ 素材完全就绪,向量索引成功")
            except Exception as e:
                # Indexing failure is non-fatal: keyword search still works.
                logger.exception("Vector indexing failed for %s", file_id)
                emit("done", f"⚠️ 素材已就绪,向量索引失败: {e}")
        else:
            emit("done", f"✓ 素材就绪 ({len(text_chunks)} 分块)")

        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.vector_status = "done"
                s.commit()

    except Exception as e:
        logger.exception("Parse and index failed for %s", file_name)
        emit("error", f"处理失败: {e}")
        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.vector_status = "error"
                s.commit()
|
|
|
|
|
|
# ══════════════════════════════════════════════
|
|
# Entry point
|
|
# ══════════════════════════════════════════════
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
port = int(os.environ.get("ENGIMIND_PORT", "9231"))
|
|
uvicorn.run(app, host="127.0.0.1", port=port, log_level="info")
|