Files
AI-Writie-Assistant/server/main.py
T
2026-04-28 10:46:56 +08:00

1074 lines
41 KiB
Python

"""
Engimind Python Backend — FastAPI application
All API routes matching the original Go/Wails backend functionality.
"""
from __future__ import annotations
import asyncio
import base64
import logging
import os
import shutil
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from database import db
from models import (
LLMProvider, VectorDBConfig, EmbeddingModelConfig, Project,
SourceFile, ChatMessage, TemplateChapter, TextChunk, DeliveryStandard,
LLMProviderSchema, VectorDBConfigSchema, EmbeddingConfigSchema,
ProjectSchema, SourceFileSchema, ChatMessageSchema,
TemplateChapterSchema, DeliveryStandardSchema,
)
from llm_client import llm_client
from vector_service import (
embedding_service, vector_store, rag_service,
chunk_text, search_text_chunks_keyword,
)
from parsers.registry import (
parse_file, detect_file_type, categorize_file, get_file_size,
SUPPORTED_EXTENSIONS,
)
from parsers.excel_parser import parse_excel_to_chunks, pre_parse_excel
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Shared logger for every message emitted by this module.
logger = logging.getLogger("engimind")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialise shared state on startup, release clients on shutdown.

    Startup:
      * initialise the global database (``db.init_global``);
      * connect the vector store when a VectorDB endpoint is configured
        (a failure is logged as a warning, not fatal);
      * re-open the most recently updated project so project-scoped
        endpoints work without an explicit switch (failure is logged).

    Shutdown closes the LLM client, the embedding service and the vector
    store. Each step is isolated and runs from ``finally``, so one failing
    close does not prevent the others.
    """
    db.init_global()
    logger.info("Engimind backend started")

    # Try connecting to Qdrant if configured
    with db.global_session() as s:
        vdb = s.query(VectorDBConfig).first()
        if vdb and vdb.endpoint:
            try:
                vector_store.connect(vdb.endpoint)
            except Exception as e:
                logger.warning("Qdrant connection failed at startup: %s", e)

    # Auto-open the most recent project so API calls work immediately
    with db.global_session() as s:
        latest = s.query(Project).order_by(Project.updated_at.desc()).first()
        if latest and latest.path:
            try:
                db.open_project(latest.id, latest.path)
                logger.info("Auto-opened project: %s (%s)", latest.name, latest.id)
            except Exception as e:
                logger.warning("Failed to auto-open project: %s", e)

    try:
        yield
    finally:
        # Close each async client independently so a failure in one
        # still lets the rest release their resources.
        for label, closer in (
            ("LLM client", llm_client.close),
            ("embedding service", embedding_service.close),
        ):
            try:
                await closer()
            except Exception:
                logger.exception("Failed to close %s", label)
        try:
            vector_store.disconnect()
        except Exception:
            logger.exception("Failed to disconnect vector store")
        logger.info("Engimind backend stopped")
app = FastAPI(title="Engimind", version="2.0.0", lifespan=lifespan)

# Allowed CORS origins: comma-separated list in ENGIMIND_CORS_ORIGINS, "*" (any
# origin) by default, which keeps the previous behaviour.
# NOTE(review): with "*" any web page the user visits may be able to call this
# local API (which includes endpoints that read arbitrary local paths, e.g.
# /api/parse-file). Consider setting ENGIMIND_CORS_ORIGINS to the frontend's
# actual origin(s) in production builds.
_CORS_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("ENGIMIND_CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ══════════════════════════════════════════════
# Health
# ══════════════════════════════════════════════
@app.get("/health")
async def health():
    """Liveness probe; always answers ``{"status": "ok"}``."""
    return dict(status="ok")
# ══════════════════════════════════════════════
# Config: LLM Providers
# ══════════════════════════════════════════════
@app.get("/api/providers", response_model=list[LLMProviderSchema])
async def get_all_providers():
    """Return every LLM provider stored in the global database."""
    with db.global_session() as session:
        providers = session.query(LLMProvider).all()
        return list(map(LLMProviderSchema.from_orm_obj, providers))
@app.post("/api/providers")
async def save_provider(p: LLMProviderSchema):
    """Insert or update the LLM provider identified by ``p.id``.

    Schema fields map to columns as: url -> base_url, key -> api_key,
    model -> model_id; name/provider/enabled keep their names.
    """
    with db.global_session() as session:
        existing = session.query(LLMProvider).filter_by(id=p.id).first()
        if existing is None:
            session.add(LLMProvider(
                id=p.id,
                name=p.name,
                provider=p.provider,
                base_url=p.url,
                api_key=p.key,
                model_id=p.model,
                enabled=p.enabled,
            ))
        else:
            existing.name = p.name
            existing.provider = p.provider
            existing.base_url = p.url
            existing.api_key = p.key
            existing.model_id = p.model
            existing.enabled = p.enabled
        session.commit()
    return {"ok": True}
@app.delete("/api/providers/{provider_id}")
async def delete_provider(provider_id: str):
    """Delete the provider with ``provider_id``; an unknown id still returns ok."""
    with db.global_session() as session:
        doomed = session.query(LLMProvider).filter_by(id=provider_id)
        doomed.delete()
        session.commit()
    return {"ok": True}
# ══════════════════════════════════════════════
# Config: VectorDB
# ══════════════════════════════════════════════
@app.get("/api/vector-db", response_model=VectorDBConfigSchema)
async def get_vector_db():
    """Return the stored VectorDB settings, or the schema defaults if none are saved."""
    with db.global_session() as session:
        cfg = session.query(VectorDBConfig).first()
        if cfg is None:
            return VectorDBConfigSchema()
        return VectorDBConfigSchema(
            endpoint=cfg.endpoint,
            apiKey=cfg.api_key,
            status=cfg.status,
        )
@app.post("/api/vector-db")
async def save_vector_db(c: VectorDBConfigSchema):
    """Upsert the single VectorDB configuration row (created with id=1).

    NOTE(review): in this file ``vector_store.connect`` is only called from
    ``lifespan`` at startup, so a newly saved endpoint is not connected until
    the backend restarts — confirm that is intended.
    """
    with db.global_session() as session:
        cfg = session.query(VectorDBConfig).first()
        if cfg is None:
            session.add(VectorDBConfig(id=1, endpoint=c.endpoint, api_key=c.apiKey, status=c.status))
        else:
            cfg.endpoint = c.endpoint
            cfg.api_key = c.apiKey
            cfg.status = c.status
        session.commit()
    return {"ok": True}
@app.post("/api/vector-db/test")
async def test_vector_db(body: dict):
    """Check that a Qdrant endpoint answers ``get_collections``.

    ``QdrantClient`` is synchronous (5 s timeout), so the probe runs in a worker
    thread instead of stalling the event loop. The client is always closed,
    even when the probe fails.

    Raises:
        HTTPException(400): import, connection or request failure.
    """
    endpoint = body.get("endpoint", "")

    def _probe() -> None:
        from qdrant_client import QdrantClient

        client = QdrantClient(url=endpoint, timeout=5)
        try:
            client.get_collections()
        finally:
            client.close()

    try:
        await asyncio.to_thread(_probe)
    except Exception as e:
        raise HTTPException(400, f"连接失败: {e}") from e
    return {"ok": True}
# ══════════════════════════════════════════════
# Config: Embedding
# ══════════════════════════════════════════════
@app.get("/api/embedding", response_model=EmbeddingConfigSchema)
async def get_embedding_config():
    """Return the stored embedding-model settings, or schema defaults if none exist."""
    with db.global_session() as session:
        cfg = session.query(EmbeddingModelConfig).first()
        return EmbeddingConfigSchema.from_orm_obj(cfg) if cfg else EmbeddingConfigSchema()
@app.post("/api/embedding")
async def save_embedding_config(c: EmbeddingConfigSchema):
    """Upsert the single embedding-model configuration row (created with id=1)."""
    with db.global_session() as session:
        cfg = session.query(EmbeddingModelConfig).first()
        if cfg is None:
            session.add(EmbeddingModelConfig(
                id=1,
                provider=c.provider,
                base_url=c.url,
                api_key=c.key,
                model_id=c.model,
                enabled=c.enabled,
                dimensions=c.dimensions,
            ))
        else:
            cfg.provider = c.provider
            cfg.base_url = c.url
            cfg.api_key = c.key
            cfg.model_id = c.model
            cfg.enabled = c.enabled
            cfg.dimensions = c.dimensions
        session.commit()
    return {"ok": True}
@app.post("/api/embedding/test")
async def test_embedding(body: dict):
    """Embed the word "hello" with the given settings and report the vector size.

    ``url``, ``model`` and ``provider`` are required keys of ``body``; ``key``
    is optional. Any failure (including a missing key or an empty vector) is
    reported as HTTP 400.
    """
    try:
        vector = await embedding_service.get_embedding(
            "hello",
            body["url"],
            body["model"],
            body.get("key", ""),
            body["provider"],
        )
        if not vector:
            raise HTTPException(400, "返回的向量为空")
        dimension = len(vector)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(400, f"Embedding 测试失败: {e}")
    return {"ok": True, "dim": dimension}
@app.post("/api/llm/test")
async def test_llm(body: dict):
    """Probe an LLM provider's model-listing endpoint.

    The probed path depends on ``provider``:
      * "Ollama": ``/api/tags`` (no Authorization header);
      * "DeepSeek" / "OpenAI": ``/models``;
      * anything else: ``/v1/models``.
    A non-empty ``key`` is sent as a Bearer token (except for Ollama).

    Returns ``{"ok": True}`` for any 2xx/3xx response.

    Raises:
        HTTPException(400): 401 responses ("API Key 无效"), any other status,
            network errors, or a malformed URL.
    """
    import httpx

    provider = body.get("provider", "")
    # ``or ""`` also covers an explicit JSON null, which previously crashed .rstrip().
    base_url = (body.get("url") or "").rstrip("/")
    api_key = body.get("key", "")

    if provider == "Ollama":
        url = base_url + "/api/tags"
    elif provider in ("DeepSeek", "OpenAI"):
        url = base_url + "/models"
    else:
        url = base_url + "/v1/models"

    headers = {}
    if api_key and provider != "Ollama":
        headers["Authorization"] = f"Bearer {api_key}"

    try:
        async with httpx.AsyncClient(timeout=8) as client:
            resp = await client.get(url, headers=headers)
    except (httpx.HTTPError, httpx.InvalidURL) as e:
        # InvalidURL is not an HTTPError subclass; without it a malformed
        # URL escaped as a 500 instead of a 400.
        raise HTTPException(400, f"网络连通性异常: {e}") from e

    if 200 <= resp.status_code < 400:
        return {"ok": True}
    if resp.status_code == 401:
        raise HTTPException(400, "API Key 无效")
    raise HTTPException(400, f"HTTP {resp.status_code}")
# ══════════════════════════════════════════════
# Projects
# ══════════════════════════════════════════════
@app.get("/api/projects", response_model=list[ProjectSchema])
async def list_projects():
    """List all registered projects, most recently created first."""
    with db.global_session() as session:
        newest_first = session.query(Project).order_by(Project.created_at.desc())
        return list(map(ProjectSchema.model_validate, newest_first.all()))
@app.post("/api/projects", response_model=ProjectSchema)
async def create_project(body: dict):
    """Create a project directory and database, register it, and make it active.

    The id is ``p-<epoch milliseconds>``. The project's database file lives at
    ``<projects_base_dir>/<id>/project.db``.
    """
    project_name = body.get("name", "未命名项目")
    project_id = f"p-{int(time.time() * 1000)}"
    project_dir = db.projects_base_dir() / project_id
    project_dir.mkdir(parents=True, exist_ok=True)
    project_db_path = str(project_dir / "project.db")

    with db.global_session() as session:
        session.add(Project(id=project_id, name=project_name, path=project_db_path))
        session.commit()

    db.open_project(project_id, project_db_path)
    return ProjectSchema(id=project_id, name=project_name)
@app.post("/api/projects/{project_id}/switch")
async def switch_project(project_id: str):
    """Make ``project_id`` the active project; 404 if it is not registered."""
    with db.global_session() as session:
        project = session.query(Project).filter_by(id=project_id).first()
        if project is None:
            raise HTTPException(404, "项目不存在")
        db.open_project(project.id, project.path)
    return {"ok": True}
@app.delete("/api/projects/{project_id}")
async def delete_project(project_id: str):
    """Delete a project: close it if active, remove its directory, drop its row.

    The directory removed is the parent of the project's database file, and
    only when that directory lies strictly inside ``db.projects_base_dir()``.
    A project whose ``path`` is empty keeps its files. So does one whose path
    points elsewhere; in that case a warning is logged. This ensures a bad
    ``path`` can never make this endpoint delete an unrelated directory. Unknown
    ids are a no-op that still returns ``{"ok": True}``.
    """
    if db.current_project_id == project_id:
        db.close_project()
    with db.global_session() as s:
        proj = s.query(Project).filter_by(id=project_id).first()
        if proj:
            if proj.path:
                base_dir = Path(db.projects_base_dir()).resolve()
                proj_dir = Path(proj.path).resolve().parent
                if base_dir in proj_dir.parents:
                    if proj_dir.is_dir():
                        shutil.rmtree(proj_dir, ignore_errors=True)
                else:
                    logger.warning("Not deleting %s: outside projects dir %s",
                                   proj_dir, base_dir)
            s.delete(proj)
            s.commit()
    return {"ok": True}
# ══════════════════════════════════════════════
# Materials
# ══════════════════════════════════════════════
@app.get("/api/materials", response_model=list[SourceFileSchema])
async def get_project_files():
    """Return the active project's materials, newest first ([] when no project is open)."""
    session = db.project_session()
    project_id = db.current_project_id
    if not (session and project_id):
        return []
    with session as s:
        rows = (
            s.query(SourceFile)
            .filter_by(project_id=project_id)
            .order_by(SourceFile.created_at.desc())
            .all()
        )
        return [SourceFileSchema.from_orm_obj(row) for row in rows]
@app.get("/api/materials/{file_id}/content")
async def get_material_content(file_id: str):
    """Return a material's content for online preview.

    - Excel: ``{"type": "excel", "sheets": [{"name": ..., "rows": [[cell, ...], ...]}]}``
      read from the stored workbook. Sheets with an empty grid are skipped.
    - Others: ``{"type": <type>, "content": <markdown>}``. The cached
      ``parsed_content`` is used when present; otherwise the file is re-parsed.

    Workbook reading and file parsing are synchronous, so they run in a worker
    thread to keep the event loop (and open SSE streams) responsive.

    Raises:
        HTTPException(400): project database not ready, or stored file missing.
        HTTPException(404): unknown ``file_id``.
    """
    session = db.project_session()
    if not session:
        raise HTTPException(400, "项目数据库未就绪")
    with session as s:
        sf = s.query(SourceFile).filter_by(id=file_id).first()
        if not sf:
            raise HTTPException(404, "素材未找到")
        file_path = sf.file_path
        file_type = sf.type
        cached_content = sf.parsed_content or ""
    if not file_path or not os.path.isfile(file_path):
        raise HTTPException(400, "文件不存在")

    if file_type == "excel":
        # Return structured table data for rich rendering
        from parsers.excel_parser import _iter_sheets, _cell_str

        def _read_sheets() -> list[dict]:
            return [
                {"name": sheet_name, "rows": [[_cell_str(c) for c in row] for row in grid]}
                for sheet_name, grid, _ in _iter_sheets(file_path)
                if grid
            ]

        sheets = await asyncio.to_thread(_read_sheets)
        return {"type": "excel", "sheets": sheets}

    # For non-excel: return cached markdown or re-parse
    if cached_content:
        return {"type": file_type, "content": cached_content}
    result = await asyncio.to_thread(parse_file, file_path)
    return {"type": file_type, "content": result.get("markdown", "")}
class UploadRequest(BaseModel):
    """Body of POST /api/materials/upload."""

    # Local file-system paths chosen in the desktop file dialog (camelCase is the JSON key).
    filePaths: list[str]
@app.post("/api/materials/upload", response_model=list[SourceFileSchema])
async def upload_materials(req: UploadRequest):
    """Upload materials by local file paths (Electron sends paths from dialog).

    Each path that is a regular file is copied into
    ``<projects_base_dir>/<project>/files`` as ``<file_id>_<name>``. Any other
    path is skipped silently.

    An earlier material with the same name in the project is replaced. Its
    stored file, text chunks, DB row and (if the vector store is connected)
    vectors are removed.

    A new SourceFile row is created with ``vector_status="pending"``, and
    parsing/indexing is scheduled as a background task.

    Returns:
        One SourceFileSchema per imported file, each with ``vectorStatus="pending"``.

    Raises:
        HTTPException(400): no active project, or its database is not ready.
    """
    pid = db.current_project_id
    if not pid:
        raise HTTPException(400, "请先创建或选择一个项目")
    session = db.project_session()
    if not session:
        raise HTTPException(400, "项目数据库未就绪")

    files_dir = db.projects_base_dir() / pid / "files"
    files_dir.mkdir(parents=True, exist_ok=True)

    results = []
    for i, src_path in enumerate(req.filePaths):
        if not os.path.isfile(src_path):
            continue
        file_name = os.path.basename(src_path)
        file_id = f"f-{int(time.time() * 1000)}-{i}"
        file_type = detect_file_type(src_path)
        file_size = get_file_size(src_path)
        category = categorize_file(file_type)
        dst_path = str(files_dir / f"{file_id}_{file_name}")
        shutil.copy2(src_path, dst_path)

        with session as s:
            # Remove duplicates (same file name in this project)
            existing = s.query(SourceFile).filter_by(project_id=pid, name=file_name).all()
            for old in existing:
                if old.file_path and os.path.isfile(old.file_path):
                    os.remove(old.file_path)
                s.query(TextChunk).filter_by(source_id=old.id).delete()
                s.delete(old)
                if vector_store.connected:
                    vector_store.delete_by_source(pid, old.id)
            s.add(SourceFile(id=file_id, project_id=pid, name=file_name,
                             type=file_type, category=category,
                             file_path=dst_path, size=file_size, vector_status="pending"))
            s.commit()

        results.append(SourceFileSchema(id=file_id, name=file_name, type=file_type,
                                        category=category,
                                        size=file_size, vectorStatus="pending"))

        # Parse and index in background. The event loop keeps only weak
        # references to tasks, so hold a strong one until the task finishes;
        # otherwise it can be garbage-collected mid-run.
        task = asyncio.create_task(_parse_and_index(pid, file_id, file_name, file_type, dst_path))
        pending = getattr(app.state, "background_tasks", None)
        if pending is None:
            pending = app.state.background_tasks = set()
        pending.add(task)
        task.add_done_callback(pending.discard)
    return results
@app.delete("/api/materials/{file_id}")
async def delete_material(file_id: str):
    """Delete one material: its stored file, text chunks, DB row and indexed vectors.

    Raises:
        HTTPException(400): project database not ready.
        HTTPException(404): unknown ``file_id``.
    """
    session = db.project_session()
    if not session:
        raise HTTPException(400, "项目数据库未就绪")
    project_id = db.current_project_id
    with session as s:
        material = s.query(SourceFile).filter_by(id=file_id).first()
        if not material:
            raise HTTPException(404, "素材未找到")
        stored_path = material.file_path
        if stored_path and os.path.isfile(stored_path):
            os.remove(stored_path)
        s.query(TextChunk).filter_by(source_id=file_id).delete()
        s.delete(material)
        s.commit()
    if vector_store.connected and project_id:
        vector_store.delete_by_source(project_id, file_id)
    return {"ok": True}
# ══════════════════════════════════════════════
# Chat
# ══════════════════════════════════════════════
@app.get("/api/chat/messages", response_model=list[ChatMessageSchema])
async def get_chat_messages():
    """Return the active project's chat history in creation order ([] without a project)."""
    session = db.project_session()
    project_id = db.current_project_id
    if not (session and project_id):
        return []
    with session as s:
        history = (
            s.query(ChatMessage)
            .filter_by(project_id=project_id)
            .order_by(ChatMessage.created_at)
            .all()
        )
        return [ChatMessageSchema.from_orm_obj(msg) for msg in history]
class SaveMessageRequest(BaseModel):
    """Body of POST /api/chat/messages; stored verbatim as one ChatMessage row."""

    role: str
    content: str
    # Optional extras stored as-is in the ChatMessage row; the frontend defines their format.
    sources: str = ""
    citations: str = ""
@app.post("/api/chat/messages")
async def save_chat_message(req: SaveMessageRequest):
    """Append one message to the active project's chat history (400 without a project)."""
    session = db.project_session()
    project_id = db.current_project_id
    if not session or not project_id:
        raise HTTPException(400, "无活动项目")
    message = ChatMessage(project_id=project_id, role=req.role, content=req.content,
                          sources=req.sources, citations=req.citations)
    with session as s:
        s.add(message)
        s.commit()
    return {"ok": True}
@app.delete("/api/chat/messages")
async def clear_chat_messages():
    """Delete the active project's chat history; a no-op without an active project."""
    session = db.project_session()
    project_id = db.current_project_id
    if session and project_id:
        with session as s:
            s.query(ChatMessage).filter_by(project_id=project_id).delete()
            s.commit()
    return {"ok": True}
class SendMessageRequest(BaseModel):
    """Body of POST /api/chat/send and POST /api/chat/stream."""

    content: str  # the user's question; also the retrieval query
    # Materials to search; an empty list is passed to the search as None.
    selectedFileIds: list[str] = []
    modelId: str  # id of an enabled LLMProvider row
# System prompt for the RAG chat endpoints (/api/chat/send and /api/chat/stream).
# In summary it tells the model to:
#   1. answer only from the supplied Context material;
#   2. say the data was not found instead of inventing numbers;
#   3. never search online;
#   4. treat data paths as the same attribute regardless of word order;
#   5. point out missing information and cite sources as [N].
# The text is sent to the model verbatim — edit with care.
_CHAT_SYSTEM_PROMPT = (
    "你是一位专业的工程技术助手。你只能从提供的 Context 素材中提取数值和信息来回答问题。\n"
    "重要约束:\n"
    "1. 如果 Context 中没有与问题相关的数据,必须诚实回答『在当前素材中未找到相关数据』,严禁编造数值或结论。\n"
    "2. 绝对不要联网搜索或自行猜测数据。\n"
    "3. 素材中的数据路径可能与用户提问的词序不同。例如素材写『湿地 > 国家所有』而用户问『国家所有湿地』,"
    "你必须识别出这是同一个属性并正确引用数值。\n"
    "4. 如果素材信息不完整,应指出缺失部分,而非自行填补。\n"
    "5. 引用来源时使用 [N] 标注。"
)
@app.post("/api/chat/send")
async def send_message(req: SendMessageRequest):
    """Non-streaming chat with RAG context.

    Retrieves up to 5 material chunks for the question. When any are found,
    they are prepended to the user turn. Returns the first choice's message
    content.

    Raises:
        HTTPException(400): unknown or disabled model.
        HTTPException(500): the model returned no choices.
    """
    provider = _get_provider(req.modelId)
    context_text = _search_material_context(req.content, req.selectedFileIds, 5)
    user_content = (
        f"参考以下工程素材:\n\n{context_text}\n\n---\n\n用户提问:{req.content}"
        if context_text
        else req.content
    )
    messages = [
        {"role": "system", "content": _CHAT_SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    result = await llm_client.complete(provider["url"], provider["key"], provider["model"], messages)
    choices = result.get("choices")
    if not choices:
        raise HTTPException(500, "模型无响应")
    return {"content": choices[0]["message"]["content"]}
@app.post("/api/chat/stream")
async def stream_message(req: SendMessageRequest):
    """Streaming chat with RAG context via SSE.

    Retrieves up to 5 material chunks for the question and prepends them to the
    user turn when any are found. Every chunk yielded by
    ``llm_client.stream_complete`` is sent as one SSE ``data`` field,
    JSON-encoded with non-ASCII characters kept as-is.
    """
    provider = _get_provider(req.modelId)
    context_text = _search_material_context(req.content, req.selectedFileIds, 5)
    if context_text:
        user_turn = f"参考以下工程素材:\n\n{context_text}\n\n---\n\n用户提问:{req.content}"
    else:
        user_turn = req.content
    messages = [
        {"role": "system", "content": _CHAT_SYSTEM_PROMPT},
        {"role": "user", "content": user_turn},
    ]

    async def event_gen():
        import json

        stream = llm_client.stream_complete(provider["url"], provider["key"], provider["model"], messages)
        async for piece in stream:
            yield {"data": json.dumps(piece, ensure_ascii=False)}

    return EventSourceResponse(event_gen())
# ══════════════════════════════════════════════
# Template / Delivery Standard
# ══════════════════════════════════════════════
@app.get("/api/template/chapters", response_model=list[TemplateChapterSchema])
async def get_template_chapters():
    """Return the active project's template chapters in ``sort_order`` ([] without a project)."""
    session = db.project_session()
    project_id = db.current_project_id
    if not (session and project_id):
        return []
    with session as s:
        ordered = (
            s.query(TemplateChapter)
            .filter_by(project_id=project_id)
            .order_by(TemplateChapter.sort_order)
            .all()
        )
        return [TemplateChapterSchema.from_orm_obj(chapter) for chapter in ordered]
@app.post("/api/template/chapters")
async def save_template_chapters(chapters: list[TemplateChapterSchema]):
    """Replace the active project's chapters with ``chapters``.

    The list position becomes ``sort_order``.

    Raises:
        HTTPException(400): no active project.
    """
    session = db.project_session()
    project_id = db.current_project_id
    if not session or not project_id:
        raise HTTPException(400, "无活动项目")
    with session as s:
        s.query(TemplateChapter).filter_by(project_id=project_id).delete()
        s.add_all(
            TemplateChapter(
                id=chapter.id,
                project_id=project_id,
                title=chapter.title,
                status=chapter.status,
                progress=chapter.progress,
                content=chapter.content,
                sort_order=position,
            )
            for position, chapter in enumerate(chapters)
        )
        s.commit()
    return {"ok": True}
@app.delete("/api/template/chapters/{chapter_id}")
async def delete_template_chapter(chapter_id: str):
    """Delete one template chapter by id (400 when no project database is open)."""
    project_session = db.project_session()
    if not project_session:
        raise HTTPException(400, "无活动项目")
    with project_session as s:
        matching = s.query(TemplateChapter).filter_by(id=chapter_id)
        matching.delete()
        s.commit()
    return {"ok": True}
@app.post("/api/template/extract-directory")
async def stream_template_directory(body: dict):
    """Stream LLM directory extraction from delivery standard text.

    ``body["content"]`` is the standard's text and ``body["modelId"]`` selects
    the provider. The model is asked for a JSON chapter tree (id/title/content
    plus children). Its output chunks are forwarded as SSE ``data`` fields.
    """
    standard_text = body.get("content", "")
    provider = _get_provider(body.get("modelId", ""))

    system_prompt = "你是一个专业的结构化数据抽取工具。你只输出合法的 JSON,不要使用 Markdown 代码块包裹,也不要给出任何其他解释。"
    user_prompt = (
        "你是一个工程标准的目录解析助手。请从下面提供的交付标准文本中提取完整的章节目录结构(包括主章节和小节),并以 JSON 数组的格式返回。\n\n"
        "### 要求:\n"
        "1. 只返回 JSON 数组,不包含其他废话或者回答前缀。\n"
        "2. 必须提取所有层级:主章节和其下的小节(子章节)。\n"
        "3. 输出格式必须严格符合:\n"
        '[{"id": "chapter-1", "title": "1 原材料进场检验", "content": "简要描述(可选)", '
        '"children": [{"id": "chapter-1-1", "title": "1.1 钢材检验", "content": ""}]}]\n'
        "4. 如果某个章节没有小节,children 可以为空数组 []。\n\n"
        f"### 交付标准内容:\n{standard_text}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    async def event_gen():
        import json

        stream = llm_client.stream_complete(provider["url"], provider["key"], provider["model"], messages)
        async for piece in stream:
            yield {"data": json.dumps(piece, ensure_ascii=False)}

    return EventSourceResponse(event_gen())
@app.get("/api/delivery-standard", response_model=Optional[DeliveryStandardSchema])
async def get_delivery_standard():
    """Return the active project's delivery standard, or None if there is none."""
    session = db.project_session()
    project_id = db.current_project_id
    if not session or not project_id:
        return None
    with session as s:
        row = s.query(DeliveryStandard).filter_by(project_id=project_id).first()
        if row is None:
            return None
        return DeliveryStandardSchema(fileName=row.file_name, content=row.content)
@app.post("/api/delivery-standard")
async def save_delivery_standard(body: DeliveryStandardSchema):
    """Replace the active project's delivery standard (400 without a project)."""
    session = db.project_session()
    project_id = db.current_project_id
    if not session or not project_id:
        raise HTTPException(400, "无活动项目")
    replacement = DeliveryStandard(project_id=project_id,
                                   file_name=body.fileName, content=body.content)
    with session as s:
        s.query(DeliveryStandard).filter_by(project_id=project_id).delete()
        s.add(replacement)
        s.commit()
    return {"ok": True}
@app.post("/api/parse-file")
async def parse_file_endpoint(body: dict):
    """Parse a local file and return ``parse_file``'s result (used for delivery standards).

    ``parse_file`` is synchronous, so it runs in a worker thread rather than
    blocking the event loop for the duration of the parse.

    Raises:
        HTTPException(400): ``filePath`` missing or not an existing file.
    """
    file_path = body.get("filePath", "")
    if not file_path or not os.path.isfile(file_path):
        raise HTTPException(400, "文件不存在")
    return await asyncio.to_thread(parse_file, file_path)
# ══════════════════════════════════════════════
# Excel Pre-parse & Final Ingest
# ══════════════════════════════════════════════
class PreParseRequest(BaseModel):
    """Body of POST /api/materials/pre-parse."""

    fileId: str
    # Optional override passed to pre_parse_excel as start_row (used by "refresh preview").
    startRow: Optional[int] = None
@app.post("/api/materials/pre-parse")
async def pre_parse_material(req: PreParseRequest):
    """Interface A: Pre-parse an Excel file and return preview data.

    ``pre_parse_excel`` does synchronous workbook I/O, so it runs in a worker
    thread to keep the event loop responsive.

    Raises:
        HTTPException(400): database not ready, not an Excel material, or file missing.
        HTTPException(404): unknown ``fileId``.
    """
    session = db.project_session()
    if not session:
        raise HTTPException(400, "项目数据库未就绪")
    with session as s:
        sf = s.query(SourceFile).filter_by(id=req.fileId).first()
        if not sf:
            raise HTTPException(404, "素材未找到")
        file_path = sf.file_path
        file_type = sf.type
    if file_type != "excel":
        raise HTTPException(400, "仅支持 Excel 文件的预解析")
    if not file_path or not os.path.isfile(file_path):
        raise HTTPException(400, "文件不存在")
    return await asyncio.to_thread(pre_parse_excel, file_path, start_row=req.startRow)
class FinalIngestRequest(BaseModel):
    """Body of POST /api/materials/ingest."""

    fileId: str
    # Confirmed first data row, forwarded to parse_excel_to_chunks as start_row.
    startRow: int
@app.post("/api/materials/ingest")
async def final_ingest_material(req: FinalIngestRequest):
    """Interface B: Final ingest — parse with confirmed start_row and vector index.

    Validates the material and schedules ``_ingest_excel`` as a background
    task. Progress is reported through /api/events.

    Raises:
        HTTPException(400): no project, database not ready, not Excel, or file missing.
        HTTPException(404): unknown ``fileId``.
    """
    pid = db.current_project_id
    if not pid:
        raise HTTPException(400, "请先选择一个项目")
    session = db.project_session()
    if not session:
        raise HTTPException(400, "项目数据库未就绪")
    with session as s:
        sf = s.query(SourceFile).filter_by(id=req.fileId).first()
        if not sf:
            raise HTTPException(404, "素材未找到")
        file_path = sf.file_path
        file_name = sf.name
        file_type = sf.type
    if file_type != "excel":
        raise HTTPException(400, "仅支持 Excel 文件")
    if not file_path or not os.path.isfile(file_path):
        raise HTTPException(400, "文件不存在")

    # Launch background ingest task. The event loop keeps only weak references
    # to tasks, so hold a strong one until it finishes.
    task = asyncio.create_task(_ingest_excel(pid, req.fileId, file_name, file_path, req.startRow))
    pending = getattr(app.state, "background_tasks", None)
    if pending is None:
        pending = app.state.background_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)
    return {"ok": True, "message": f"开始入库,数据起始行: {req.startRow}"}
async def _ingest_excel(project_id: str, file_id: str, file_name: str,
                        file_path: str, start_row: int):
    """Background: parse Excel with confirmed start_row, clear old data, re-index.

    Steps:
      1. Mark the SourceFile "processing".
      2. Parse the workbook into chunks with ``parse_excel_to_chunks``. It is
         synchronous, so it runs in a worker thread.
      3. Delete the file's old TextChunk rows, and its old vectors when the
         vector store is connected.
      4. Store the new chunks.
      5. Vector-index them when an embedding model is configured and the store
         is connected.
      6. Mark the row "done".

    Any exception marks it "error". Progress is broadcast to /api/events
    subscribers as ``material_status_update`` events.
    """
    import json

    session = db.project_session()
    if not session:
        return

    def emit(status: str, step: str):
        _broadcast_event(json.dumps({
            "type": "material_status_update",
            "fileId": file_id, "fileName": file_name,
            "status": status, "step": step,
        }, ensure_ascii=False))

    def set_status(status: str) -> None:
        # Persist the SourceFile's vector_status; a missing row is ignored.
        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.vector_status = status
            s.commit()

    try:
        set_status("processing")
        emit("processing", f"📄 开始重新解析「{file_name}」(起始行: {start_row})")

        # Step 1: Parse with confirmed start_row (blocking I/O -> worker thread)
        raw_chunks = await asyncio.to_thread(parse_excel_to_chunks, file_path, start_row=start_row)
        text_chunks = [c['content'] for c in raw_chunks]
        metadata_list = [c['metadata'] for c in raw_chunks]
        emit("processing", f"🔪 文本分块: 共 {len(text_chunks)} 个片段")

        # Step 2: Clear old data
        with session as s:
            s.query(TextChunk).filter_by(source_id=file_id).delete()
            s.commit()
        if vector_store.connected:
            vector_store.delete_by_source(project_id, file_id)
            emit("processing", "🗑️ 已清除旧向量数据")

        # Step 3: Store chunks in SQLite
        with session as s:
            for i, chunk in enumerate(text_chunks):
                s.add(TextChunk(id=f"{file_id}-chunk-{i}", project_id=project_id,
                                source_id=file_id, content=chunk, chunk_idx=i))
            s.commit()

        # Step 4: Vector index
        emb_cfg = _get_embedding_config()
        if emb_cfg and vector_store.connected:
            emit("processing", f"🔗 开始向量化索引 ({len(text_chunks)} 分块)...")
            dim = emb_cfg.get("dimensions", 1024)

            def on_progress(done: int, total: int):
                emit("processing", f"🔗 向量化进度: {done}/{total}")

            await rag_service.index_chunks(
                project_id, file_id, text_chunks, emb_cfg,
                dim=dim, metadata_list=metadata_list,
                on_progress=on_progress,
            )
            emit("done", "✓ 素材完全就绪,向量索引成功")
        else:
            emit("done", f"✓ 素材就绪 ({len(text_chunks)} 分块)")
        set_status("done")
    except Exception as e:
        logger.exception("Excel ingest failed for %s", file_name)
        emit("error", f"入库失败: {e}")
        # Recording the failure must not raise out of the background task.
        try:
            set_status("error")
        except Exception:
            logger.exception("Could not record error status for %s", file_id)
# ══════════════════════════════════════════════
# Chapter Generation
# ══════════════════════════════════════════════
class GenerateChapterRequest(BaseModel):
    """Body of POST /api/chat/generate-chapter."""

    chapterTitle: str  # chapter requirement; also used as the retrieval query
    # Materials to search; an empty list is passed to the search as None.
    selectedFileIds: list[str] = []
    modelId: str  # id of an enabled LLMProvider row
@app.post("/api/chat/generate-chapter")
async def stream_generate_chapter(req: GenerateChapterRequest):
    """Stream a generated report chapter via SSE.

    Retrieves up to 8 material chunks for the chapter title and asks the model
    for a Markdown chapter that cites sources as [N]. The model is told not to
    invent data. Output chunks are forwarded as JSON-encoded SSE ``data`` fields.
    """
    provider = _get_provider(req.modelId)
    context_text = _search_material_context(req.chapterTitle, req.selectedFileIds, 8)

    user_prompt = (
        "你是工程报告撰写专家。请根据以下工程素材,按照章节要求撰写报告内容。\n\n"
        f"## 章节要求\n{req.chapterTitle}\n\n"
        f"## 参考素材\n{context_text}\n\n"
        "## 输出要求\n"
        "1. 使用 Markdown 格式\n"
        "2. 引用素材时使用 [N] 标注\n"
        "3. 内容专业、结构清晰\n"
        "4. 包含具体数据和分析结论"
    )
    system_prompt = (
        "你是一位资深工程报告撰写专家,擅长根据工程素材生成结构化的技术报告章节。\n"
        "重要约束:只能使用参考素材中的数据和信息。如果素材不足以支撑某个章节,请明确标注"
        "『素材中未包含此部分数据,需补充』,绝不可编造数值或结论。"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    async def event_gen():
        import json

        stream = llm_client.stream_complete(provider["url"], provider["key"], provider["model"], messages)
        async for piece in stream:
            yield {"data": json.dumps(piece, ensure_ascii=False)}

    return EventSourceResponse(event_gen())
# ══════════════════════════════════════════════
# SSE Event Stream (material status broadcast)
# ══════════════════════════════════════════════
# Simple in-memory event bus for material processing status
_event_subscribers: list[asyncio.Queue] = []
@app.get("/api/events")
async def event_stream():
    """SSE endpoint for real-time events (material status, etc.).

    Each connected client gets its own unbounded ``asyncio.Queue`` in
    ``_event_subscribers``. Every string handed to ``_broadcast_event`` is sent
    to it as one SSE ``data`` field.

    The queue is registered when the stream starts and removed in ``finally``
    when the client disconnects or the stream is cancelled. A stream that never
    starts therefore cannot leak a subscriber. Cancellation is allowed to
    propagate instead of being swallowed, so the server's shutdown of the
    stream completes normally.
    """
    async def gen():
        queue: asyncio.Queue = asyncio.Queue()
        _event_subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield {"data": event}
        finally:
            # Tolerate an already-removed queue: cleanup must never raise.
            if queue in _event_subscribers:
                _event_subscribers.remove(queue)

    return EventSourceResponse(gen())
def _broadcast_event(event_data: str):
    """Enqueue ``event_data`` for every connected /api/events subscriber.

    Callers in this file pass a ``json.dumps`` string. ``put_nowait`` never
    awaits, so this can be called from synchronous code.
    NOTE(review): asyncio.Queue is not thread-safe; this assumes callers run on
    the event-loop thread — confirm for progress callbacks invoked by rag_service.
    """
    for subscriber in _event_subscribers:
        subscriber.put_nowait(event_data)
# ══════════════════════════════════════════════
# Internal helpers
# ══════════════════════════════════════════════
def _get_provider(model_id: str) -> dict:
    """Resolve an enabled LLM provider row into connection settings.

    Returns:
        ``{"url", "key", "model", "provider"}`` taken from the row.

    Raises:
        HTTPException(400): no enabled provider has this id.
    """
    with db.global_session() as session:
        row = session.query(LLMProvider).filter_by(id=model_id, enabled=True).first()
        if not row:
            raise HTTPException(400, f"模型 {model_id} 未找到或已禁用")
        return dict(url=row.base_url, key=row.api_key, model=row.model_id, provider=row.provider)
def _get_embedding_config() -> dict | None:
    """Return the embedding settings as a dict, or None if not enabled or without a base URL.

    ``dimensions`` falls back to 1024 when the stored value is falsy.
    """
    with db.global_session() as session:
        cfg = session.query(EmbeddingModelConfig).first()
        if not (cfg and cfg.enabled and cfg.base_url):
            return None
        return {
            "base_url": cfg.base_url,
            "model": cfg.model_id,
            "api_key": cfg.api_key,
            "provider": cfg.provider,
            "dimensions": cfg.dimensions or 1024,
        }
def _search_material_context(query: str, file_ids: list[str], top_k: int) -> str:
    """Hybrid search over the current project's materials.

    A vector (semantic) search runs when an embedding model is configured and
    the vector store is connected. A keyword search over the stored text chunks
    always runs afterwards. The two searches run one after the other, not in
    parallel.

    Results are merged with ``_merge_search_results``: vector hits first, then
    de-duplicated keyword hits, at most ``top_k`` in total. The keyword pass
    catches exact term matches regardless of word order. This fixes the case
    where '水域集体所有' hits but '集体所有水域' misses.

    Args:
        query: search text.
        file_ids: restrict to these materials; an empty list is passed on as None.
        top_k: maximum number of chunks returned.

    Returns:
        The chunk texts joined by ``"\\n\\n---\\n\\n"``, or ``""`` when there is no
        active project or nothing matched. A vector-search failure is logged and
        the keyword results are still used.
    """
    pid = db.current_project_id
    if not pid:
        return ""

    vector_chunks: list[dict] = []
    keyword_chunks: list[dict] = []

    # ── Vector search ──
    emb_cfg = _get_embedding_config()
    if emb_cfg and vector_store.connected:
        try:
            import concurrent.futures

            chunks_coro = rag_service.search_context(
                pid, query, top_k, emb_cfg, file_ids or None,
            )
            # This helper is synchronous, so the coroutine is run to completion on
            # a private event loop in one worker thread.
            # NOTE(review): future.result() blocks the calling (async) endpoint's
            # event loop while the search runs, and any async client inside
            # rag_service bound to the main loop may fail on this second loop.
            # Consider making this helper async and awaiting search_context.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                vector_chunks = pool.submit(asyncio.run, chunks_coro).result() or []
        except Exception as e:
            logger.warning("Vector search failed: %s", e)

    # ── Keyword search (runs after the vector search) ──
    session = db.project_session()
    if session:
        with session as s:
            keyword_chunks = search_text_chunks_keyword(
                s, pid, query, file_ids or None, top_k,
            )

    # ── Merge and deduplicate ──
    merged = _merge_search_results(vector_chunks, keyword_chunks or [], top_k)
    if not merged:
        return ""
    return "\n\n---\n\n".join(c.get("text", "") for c in merged)
def _merge_search_results(vector_chunks: list[dict],
keyword_chunks: list[dict],
top_k: int) -> list[dict]:
"""Merge vector and keyword results, deduplicate by text content.
Priority: vector results first (semantically ranked), then keyword
results that weren't already found by vector search.
"""
seen_texts: set = set()
merged: list[dict] = []
def _text_key(text: str) -> str:
"""Normalize text for dedup: strip whitespace, take first 80 chars."""
return text.strip()[:80] if text else ""
# Vector results first (higher priority)
for c in vector_chunks:
key = _text_key(c.get("text", ""))
if key and key not in seen_texts:
seen_texts.add(key)
merged.append(c)
# Keyword results fill remaining slots
for c in keyword_chunks:
if len(merged) >= top_k:
break
key = _text_key(c.get("text", ""))
if key and key not in seen_texts:
seen_texts.add(key)
merged.append(c)
return merged[:top_k]
async def _parse_and_index(project_id: str, file_id: str, file_name: str,
                           file_type: str, file_path: str):
    """Background task: parse file, store chunks, optionally vector index.

    Progress is broadcast to /api/events subscribers as
    ``material_status_update`` events. The SourceFile row's ``vector_status``
    moves from "processing" to "done" or "error".

    - Excel files are chunked with ``parse_excel_to_chunks``, which provides
      per-chunk metadata. Other files use ``chunk_text(content, 800, 80)``.
    - A vector-indexing failure is reported but still ends in "done".
    - Parsing is synchronous and runs in a worker thread, so the event loop
      keeps serving other requests and SSE streams.
    """
    import json

    session = db.project_session()
    if not session:
        return

    def emit(status: str, step: str):
        _broadcast_event(json.dumps({
            "type": "material_status_update",
            "fileId": file_id, "fileName": file_name,
            "status": status, "step": step,
        }, ensure_ascii=False))

    def set_status(status: str) -> None:
        # Persist the SourceFile's vector_status; a missing row is ignored.
        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.vector_status = status
            s.commit()

    try:
        set_status("processing")
        emit("processing", f"📄 开始处理「{file_name}」")
        emit("processing", f"📖 正在解析 {file_type.upper()} 文件内容...")
        result = await asyncio.to_thread(parse_file, file_path)
        content = result.get("markdown", "")
        if not content:
            emit("error", f"解析失败: {result.get('error', '无内容')}")
            set_status("error")
            return
        emit("processing", f"✓ 文件解析完成 (提取 {len(content)} 字符)")

        # Save parsed content (used later by the preview endpoint)
        with session as s:
            sf = s.query(SourceFile).filter_by(id=file_id).first()
            if sf:
                sf.parsed_content = content
            s.commit()

        # Chunk text
        if file_type == "excel":
            raw_chunks = await asyncio.to_thread(parse_excel_to_chunks, file_path, start_row=None)
            text_chunks = [c['content'] for c in raw_chunks]
            metadata_list = [c['metadata'] for c in raw_chunks]
        else:
            text_chunks = chunk_text(content, 800, 80)
            metadata_list = None
        emit("processing", f"🔪 文本分块: 共 {len(text_chunks)} 个片段")

        # Store chunks in SQLite
        with session as s:
            for i, chunk in enumerate(text_chunks):
                s.add(TextChunk(id=f"{file_id}-chunk-{i}", project_id=project_id,
                                source_id=file_id, content=chunk, chunk_idx=i))
            s.commit()

        # Optional: vector indexing (use pre-chunked text instead of re-chunking)
        emb_cfg = _get_embedding_config()
        logger.info("Vector indexing check: emb_cfg=%s, connected=%s",
                    {k: v for k, v in emb_cfg.items() if k != 'api_key'} if emb_cfg else None,
                    vector_store.connected)
        if emb_cfg and vector_store.connected:
            emit("processing", f"🔗 开始向量化索引 ({len(text_chunks)} 分块)...")
            try:
                dim = emb_cfg.get("dimensions", 1024)
                logger.info("Calling index_chunks: project=%s, source=%s, chunks=%d, dim=%d",
                            project_id, file_id, len(text_chunks), dim)

                def on_progress(done: int, total: int):
                    emit("processing", f"🔗 向量化进度: {done}/{total}")

                await rag_service.index_chunks(
                    project_id, file_id, text_chunks, emb_cfg,
                    dim=dim, metadata_list=metadata_list,
                    on_progress=on_progress,
                )
                emit("done", "✓ 素材完全就绪,向量索引成功")
            except Exception as e:
                # Chunks are already stored, so the material stays usable.
                logger.exception("Vector indexing failed for %s", file_id)
                emit("done", f"⚠️ 素材已就绪,向量索引失败: {e}")
        else:
            emit("done", f"✓ 素材就绪 ({len(text_chunks)} 分块)")
        set_status("done")
    except Exception as e:
        logger.exception("Parse and index failed for %s", file_name)
        emit("error", f"处理失败: {e}")
        # Recording the failure must not raise out of the background task.
        try:
            set_status("error")
        except Exception:
            logger.exception("Could not record error status for %s", file_id)
# ══════════════════════════════════════════════
# Entry point
# ══════════════════════════════════════════════
if __name__ == "__main__":
    import uvicorn

    # Serve on loopback only; the port can be overridden via ENGIMIND_PORT.
    listen_port = int(os.environ.get("ENGIMIND_PORT", "9231"))
    uvicorn.run(app, host="127.0.0.1", port=listen_port, log_level="info")