# AI-Writie-Assistant/server/vector_service.py
"""Embedding service + Qdrant vector store + RAG orchestration."""
from __future__ import annotations
import hashlib
import logging
import uuid
from typing import Optional
import httpx
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchAny,
)
logger = logging.getLogger("engimind.vector")
# ────────── Embedding ──────────
class EmbeddingService:
    def __init__(self):
        self._client = httpx.AsyncClient(timeout=60.0)

    async def get_embedding(self, text: str, base_url: str, model: str,
                            api_key: str, provider: str) -> list[float]:
        if provider == "Ollama":
            return await self._ollama(text, base_url, model)
        return await self._openai(text, base_url, model, api_key)

    async def _ollama(self, text: str, base_url: str, model: str) -> list[float]:
        resp = await self._client.post(f"{base_url}/api/embeddings",
                                       json={"model": model, "prompt": text})
        resp.raise_for_status()
        data = resp.json()
        return data.get("embedding", [])

    async def _openai(self, text: str, base_url: str, model: str, api_key: str) -> list[float]:
        url = base_url.rstrip("/")
        url = url + "/embeddings" if url.endswith("/v1") else url + "/v1/embeddings"
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        # Dashscope requires non-empty input; truncate to stay within token limits
        if not text or not text.strip():
            logger.warning("Skipping empty text for embedding")
            return []
        # Safety: truncate to ~6000 chars (~8192 tokens for Chinese)
        if len(text) > 6000:
            logger.warning("Truncating embedding input from %d to 6000 chars", len(text))
            text = text[:6000]
        resp = await self._client.post(url, json={"model": model, "input": text}, headers=headers)
        if resp.status_code != 200:
            body = resp.text[:500]
            logger.error("Embedding API %d: %s (text[:80]=%r)", resp.status_code, body, text[:80])
        resp.raise_for_status()
        data = resp.json()
        if data.get("data") and data["data"][0].get("embedding"):
            return data["data"][0]["embedding"]
        return []

    async def close(self):
        await self._client.aclose()
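# EmbeddingService usage sketch (illustrative only; the endpoint and model values
# below are assumptions, not part of this module):
#
#   vec = await embedding_service.get_embedding(
#       "待嵌入的文本",
#       base_url="http://localhost:11434",  # e.g. a local Ollama instance
#       model="bge-m3", api_key="", provider="Ollama",
#   )
#   # Any provider other than "Ollama" is routed to the OpenAI-style /v1/embeddings endpoint.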
# ────────── Text Chunking ──────────
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    runes = list(text)
    if len(runes) <= chunk_size:
        return [text]
    chunks = []
    start = 0
    while start < len(runes):
        end = min(start + chunk_size, len(runes))
        chunks.append("".join(runes[start:end]))
        start += chunk_size - overlap
    return chunks
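# With the defaults (chunk_size=500, overlap=50), chunk windows start at character
# offsets 0, 450, 900, ... so consecutive chunks share a 50-character overlap.
# Callers should keep overlap < chunk_size; otherwise the loop above never advances.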
# ────────── Qdrant Store ──────────
class VectorStore:
    def __init__(self):
        self._client: Optional[QdrantClient] = None

    def connect(self, endpoint: str):
        """Connect to Qdrant. Endpoint like 'http://localhost:6333'."""
        self.disconnect()
        self._client = QdrantClient(url=endpoint, timeout=10)
        logger.info("Connected to Qdrant at %s", endpoint)

    def disconnect(self):
        if self._client:
            self._client.close()
            self._client = None

    @property
    def connected(self) -> bool:
        return self._client is not None

    def collection_name(self, project_id: str) -> str:
        return f"engimind_{project_id}"

    def ensure_collection(self, project_id: str, dim: int = 1024):
        if not self._client:
            return
        name = self.collection_name(project_id)
        if self._client.collection_exists(name):
            info = self._client.get_collection(name)
            existing_dim = info.config.params.vectors.size
            if existing_dim != dim:
                logger.warning("Collection %s dim mismatch (%d vs %d), recreating",
                               name, existing_dim, dim)
                self._client.delete_collection(name)
            else:
                return
        self._client.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
        )
        logger.info("Created collection %s (dim=%d)", name, dim)
    def insert(self, project_id: str, chunks: list[dict]):
        """Insert chunks: [{'id': str, 'source_id': str, 'text': str, 'vector': list, 'metadata': dict?}].

        Metadata keys 'sheet', 'row_number', 'file_name' are promoted to
        payload top-level for Qdrant filter support.
        """
        if not self._client:
            return
        name = self.collection_name(project_id)
        points = []
        for c in chunks:
            payload = {"text": c["text"], "source_id": c["source_id"]}
            if "metadata" in c:
                meta = c["metadata"]
                # Promote key fields to top level for Qdrant filtering
                for key in ("sheet", "row_number", "file_name"):
                    if key in meta:
                        payload[key] = meta[key]
                payload["metadata"] = meta
            points.append(PointStruct(id=c["id"], vector=c["vector"], payload=payload))
        self._client.upsert(collection_name=name, points=points)
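    # Example (illustrative): a chunk
    #   {"id": "<uuid>", "source_id": "file-1", "text": "...", "vector": [...],
    #    "metadata": {"sheet": "Sheet1", "row_number": 3}}
    # is stored with payload
    #   {"text": "...", "source_id": "file-1", "sheet": "Sheet1", "row_number": 3,
    #    "metadata": {"sheet": "Sheet1", "row_number": 3}}
    # so 'sheet' / 'row_number' can be used directly in Qdrant payload filters.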
    def search(self, project_id: str, query_vec: list[float], top_k: int = 5,
               file_ids: list[str] | None = None) -> list[dict]:
        if not self._client:
            return []
        name = self.collection_name(project_id)
        query_filter = None
        if file_ids:
            query_filter = Filter(must=[
                FieldCondition(key="source_id", match=MatchAny(any=file_ids))
            ])
        results = self._client.query_points(
            collection_name=name,
            query=query_vec,
            limit=top_k,
            query_filter=query_filter,
            with_payload=True,
        ).points
        return [
            {
                "text": r.payload.get("text", ""),
                "source_id": r.payload.get("source_id", ""),
                "score": r.score,
                "metadata": r.payload.get("metadata", {}),
            }
            for r in results
        ]
    def delete_by_source(self, project_id: str, source_id: str):
        if not self._client:
            return
        name = self.collection_name(project_id)
        self._client.delete(
            collection_name=name,
            points_selector=Filter(must=[
                FieldCondition(key="source_id", match=MatchAny(any=[source_id]))
            ]),
        )

    def test_connection(self) -> bool:
        if not self._client:
            return False
        try:
            self._client.get_collections()
            return True
        except Exception:
            return False
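# VectorStore usage sketch (illustrative; the endpoint and ids are assumptions):
#
#   vector_store.connect("http://localhost:6333")
#   if vector_store.test_connection():
#       vector_store.ensure_collection("proj-1", dim=1024)
#       hits = vector_store.search("proj-1", query_vec, top_k=5, file_ids=["file-1"])
#       # query_vec would come from EmbeddingService.get_embedding()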
# ────────── RAG Service ──────────
class RAGService:
    def __init__(self, embedding: EmbeddingService, store: VectorStore):
        self.embedding = embedding
        self.store = store

    async def index_document(self, project_id: str, source_id: str, content: str,
                             emb_config: dict):
        """Chunk, embed, and index a document."""
        text_chunks = chunk_text(content, 500, 50)
        dim = emb_config.get("dimensions", 1024)
        await self.index_chunks(project_id, source_id, text_chunks, emb_config, dim=dim)
    async def index_chunks(self, project_id: str, source_id: str,
                           text_chunks: list[str], emb_config: dict,
                           dim: int = 1024,
                           metadata_list: list[dict] | None = None,
                           on_progress: callable | None = None):
        """Embed and index pre-chunked text with optional per-chunk metadata."""
        self.store.ensure_collection(project_id, dim=dim)
        total = len(text_chunks)
        logger.info("index_chunks: %d chunks for source %s (dim=%d)", total, source_id, dim)
        batch: list[dict] = []
        batch_size = 20  # Insert in batches to reduce Qdrant round-trips
        inserted = 0
        for i, text in enumerate(text_chunks):
            if not text or not text.strip():
                continue
            vec = await self.embedding.get_embedding(
                text, emb_config["base_url"], emb_config["model"],
                emb_config["api_key"], emb_config["provider"],
            )
            if not vec:
                logger.warning("Empty vector for chunk %d, skipping", i)
                continue
            chunk_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{source_id}-chunk-{i}"))
            entry: dict = {"id": chunk_id, "source_id": source_id, "text": text, "vector": vec}
            if metadata_list and i < len(metadata_list):
                entry["metadata"] = metadata_list[i]
            batch.append(entry)
            inserted += 1
            # Flush batch
            if len(batch) >= batch_size:
                self.store.insert(project_id, batch)
                batch = []
            # Report progress
            if on_progress and (i + 1) % 10 == 0:
                on_progress(i + 1, total)
        # Final batch
        if batch:
            self.store.insert(project_id, batch)
        logger.info("Inserted vectors for source %s (%d of %d chunks)", source_id, inserted, total)
    async def search_context(self, project_id: str, question: str, top_k: int,
                             emb_config: dict, file_ids: list[str] | None = None) -> list[dict]:
        """Search for relevant chunks."""
        query_vec = await self.embedding.get_embedding(
            question, emb_config["base_url"], emb_config["model"],
            emb_config["api_key"], emb_config["provider"],
        )
        return self.store.search(project_id, query_vec, top_k, file_ids)
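# Example (illustrative): indexing with a hypothetical progress callback; `chunks`
# and `emb_config` are prepared by the caller.
#
#   def report(done: int, total: int):          # hypothetical helper
#       logger.info("indexed %d/%d chunks", done, total)
#
#   await rag_service.index_chunks("proj-1", "file-1", chunks, emb_config,
#                                  dim=1024, on_progress=report)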
# ────────── SQLite keyword search (scored ranking) ──────────
def search_text_chunks_keyword(session, project_id: str, query: str,
                               file_ids: list[str] | None, top_k: int) -> list[dict]:
    """Keyword search with scored ranking.

    1. Extract keywords from query (Chinese-aware n-gram splitting)
    2. Fetch chunks matching ANY keyword (OR, wide net)
    3. Score each chunk by how many keywords it contains
    4. Return top_k sorted by score descending
    """
    from models import TextChunk
    from sqlalchemy import or_
    keywords = _extract_keywords(query)
    if not keywords:
        return []
    q = session.query(TextChunk).filter(TextChunk.project_id == project_id)
    if file_ids:
        q = q.filter(TextChunk.source_id.in_(file_ids))
    conditions = [TextChunk.content.ilike(f"%{kw}%") for kw in keywords]
    if not conditions:
        return []
    # Fetch wider pool, then rank by keyword hit count
    fetch_limit = max(top_k * 4, 20)
    candidates = q.filter(or_(*conditions)).limit(fetch_limit).all()
    # Score each chunk: count how many keywords appear in its content
    scored = []
    for c in candidates:
        text_lower = c.content.lower()
        hits = sum(1 for kw in keywords if kw.lower() in text_lower)
        scored.append((hits, c))
    # Sort by hit count descending, take top_k
    scored.sort(key=lambda x: x[0], reverse=True)
    return [
        {"text": c.content, "source_id": c.source_id}
        for _, c in scored[:top_k]
    ]
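# Example (illustrative): for keywords ["风机基础", "数量"], a chunk containing both
# scores 2 and ranks above a chunk that matched only one of them.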
def _extract_keywords(query: str) -> list[str]:
    """Extract search keywords from a Chinese query (no jieba needed).

    Strategy:
    1. Remove stop words / particles
    2. Split on punctuation and whitespace
    3. For each segment, generate 2-4 char n-grams for Chinese text
    4. Deduplicate and return
    """
    import re
    # Common stop words / particles
    stop_chars = set("的了是在和与对有不这那我你它们都吗呢吧啊哦呀嘛")
    stop_words = {"多少", "什么", "怎么", "如何", "哪些", "哪个", "请问",
                  "告诉", "可以", "一下", "一共", "总共", "分别"}
    # Remove stop characters
    cleaned = "".join(c for c in query if c not in stop_chars)
    # Split on punctuation, whitespace, and non-CJK characters
    segments = re.split(r'[,。、?!,.?!\s\n\t:;\-—()()\[\]【】{}“”\']+', cleaned)
    keywords: list[str] = []
    seen: set = set()

    def _add(kw: str):
        if kw and len(kw) >= 2 and kw not in seen and kw not in stop_words:
            seen.add(kw)
            keywords.append(kw)

    for seg in segments:
        seg = seg.strip()
        if not seg:
            continue
        # If segment is short enough, keep as-is
        if len(seg) <= 4:
            _add(seg)
            continue
        # For longer segments, generate overlapping n-grams (2, 3, 4 chars)
        # Also keep the full segment for exact matching
        _add(seg)
        for n in (4, 3, 2):
            for i in range(len(seg) - n + 1):
                _add(seg[i:i + n])
    return keywords
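# Example (illustrative): _extract_keywords("风机基础的数量是多少?") drops the
# particles 的/是 and the stop word 多少, keeps the cleaned segment itself, and adds
# overlapping 2-4 character n-grams such as "风机基础", "基础", "数量".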
# ── Singletons ──
embedding_service = EmbeddingService()
vector_store = VectorStore()
rag_service = RAGService(embedding_service, vector_store)
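
# Minimal end-to-end sketch (not part of the service itself): wires the singletons
# together against an assumed local Qdrant and Ollama embedding endpoint. The URLs,
# model name, dimension, and project/source ids below are illustrative assumptions.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        vector_store.connect("http://localhost:6333")  # assumed local Qdrant
        emb_config = {
            "base_url": "http://localhost:11434",  # assumed local Ollama
            "model": "bge-m3",
            "api_key": "",
            "provider": "Ollama",
            "dimensions": 1024,
        }
        await rag_service.index_document("demo_project", "demo_source", "示例文档内容。", emb_config)
        hits = await rag_service.search_context("demo_project", "示例查询", 3, emb_config)
        for h in hits:
            print(h["score"], h["text"][:50])
        await embedding_service.close()
        vector_store.disconnect()

    asyncio.run(_demo())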