# (extraction residue removed: duplicated "362 lines / 13 KiB / Python" file metadata)
"""Embedding service + Qdrant vector store + RAG orchestration."""
|
||
from __future__ import annotations

import hashlib
import logging
import uuid
from typing import Callable, Optional

import httpx
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchAny,
)
|
||
|
||
logger = logging.getLogger("engimind.vector")
|
||
|
||
|
||
# ────────── Embedding ──────────
|
||
|
||
class EmbeddingService:
    """Fetch text embeddings from an Ollama or OpenAI-compatible endpoint."""

    # Truncation cap: ~6000 chars stays within ~8192-token limits for Chinese text.
    MAX_CHARS = 6000

    def __init__(self):
        # One shared async client; callers must await close() on shutdown.
        self._client = httpx.AsyncClient(timeout=60.0)

    async def get_embedding(self, text: str, base_url: str, model: str,
                            api_key: str, provider: str) -> list[float]:
        """Return the embedding vector for *text*, or [] for blank input.

        Dispatches to the native Ollama API when provider == "Ollama",
        otherwise to the OpenAI-compatible /v1/embeddings endpoint.
        """
        # BUG FIX: guard here so BOTH providers skip blank input — previously
        # only the OpenAI path checked, and Ollama could receive empty prompts.
        if not text or not text.strip():
            logger.warning("Skipping empty text for embedding")
            return []
        if provider == "Ollama":
            return await self._ollama(text, base_url, model)
        return await self._openai(text, base_url, model, api_key)

    async def _ollama(self, text: str, base_url: str, model: str) -> list[float]:
        """Call Ollama's /api/embeddings; returns [] when no vector is present."""
        resp = await self._client.post(f"{base_url}/api/embeddings",
                                       json={"model": model, "prompt": text})
        resp.raise_for_status()
        data = resp.json()
        return data.get("embedding", [])

    async def _openai(self, text: str, base_url: str, model: str, api_key: str) -> list[float]:
        """Call an OpenAI-compatible /v1/embeddings endpoint.

        Normalizes base_url so both ".../v1" and bare hosts work. Returns []
        for blank input or a response without an embedding payload.
        """
        url = base_url.rstrip("/")
        url = url + "/embeddings" if url.endswith("/v1") else url + "/v1/embeddings"
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        # Dashscope requires non-empty input; keep this guard for direct callers.
        if not text or not text.strip():
            logger.warning("Skipping empty text for embedding")
            return []
        # Safety: truncate to ~6000 chars (~8192 tokens for Chinese)
        if len(text) > self.MAX_CHARS:
            logger.warning("Truncating embedding input from %d to %d chars",
                           len(text), self.MAX_CHARS)
            text = text[:self.MAX_CHARS]
        resp = await self._client.post(url, json={"model": model, "input": text}, headers=headers)
        if resp.status_code != 200:
            # Log a bounded slice of the body and input to aid debugging.
            body = resp.text[:500]
            logger.error("Embedding API %d: %s (text[:80]=%r)", resp.status_code, body, text[:80])
        resp.raise_for_status()
        data = resp.json()
        if data.get("data") and data["data"][0].get("embedding"):
            return data["data"][0]["embedding"]
        return []

    async def close(self):
        """Release the underlying HTTP client."""
        await self._client.aclose()
|
||
|
||
|
||
# ────────── Text Chunking ──────────
|
||
|
||
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split *text* into chunks of *chunk_size* characters sharing *overlap* chars.

    Operates on Unicode code points (plain str indexing), so CJK text is
    split per character, not per byte. Returns [text] when it already fits.

    Raises:
        ValueError: if overlap >= chunk_size or chunk_size <= 0 (the loop
            could never advance / terminate).
    """
    if chunk_size <= 0 or overlap < 0 or overlap >= chunk_size:
        raise ValueError(f"invalid chunking params: chunk_size={chunk_size}, overlap={overlap}")
    if len(text) <= chunk_size:
        return [text]
    chunks: list[str] = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        chunks.append(text[start:end])
        # BUG FIX: stop once the end of the text is reached; advancing further
        # used to emit a final chunk wholly contained in the previous one
        # (e.g. 950 chars produced a redundant third chunk).
        if end == len(text):
            break
        start += chunk_size - overlap
    return chunks
|
||
|
||
|
||
# ────────── Qdrant Store ──────────
|
||
|
||
class VectorStore:
    """Qdrant-backed vector store; one collection per project.

    All operations are no-ops (or return empty results) until connect()
    succeeds, so callers don't need to pre-check connectivity.
    """

    def __init__(self):
        # Lazily connected; None until connect() is called.
        self._client: Optional[QdrantClient] = None

    def connect(self, endpoint: str):
        """Connect to Qdrant. Endpoint like 'http://localhost:6333'."""
        self.disconnect()
        self._client = QdrantClient(url=endpoint, timeout=10)
        logger.info("Connected to Qdrant at %s", endpoint)

    def disconnect(self):
        """Close the current client, if any, and reset to disconnected state."""
        if self._client:
            self._client.close()
            self._client = None

    @property
    def connected(self) -> bool:
        return self._client is not None

    def collection_name(self, project_id: str) -> str:
        """Deterministic collection name derived from the project id."""
        return f"engimind_{project_id}"

    def ensure_collection(self, project_id: str, dim: int = 1024):
        """Create the project's collection if missing.

        WARNING: if an existing collection has a different vector dimension,
        it is DROPPED and recreated — previously stored vectors are lost.
        """
        if not self._client:
            return
        name = self.collection_name(project_id)
        if self._client.collection_exists(name):
            info = self._client.get_collection(name)
            # NOTE(review): assumes a single unnamed vector config — TODO confirm
            # no collection here uses named vectors.
            existing_dim = info.config.params.vectors.size
            if existing_dim != dim:
                logger.warning("Collection %s dim mismatch (%d vs %d), recreating",
                               name, existing_dim, dim)
                self._client.delete_collection(name)
            else:
                return
        self._client.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
        )
        logger.info("Created collection %s (dim=%d)", name, dim)

    def insert(self, project_id: str, chunks: list[dict]):
        """Insert chunks: [{'id': str, 'source_id': str, 'text': str, 'vector': list, 'metadata': dict?}].

        Metadata keys 'sheet', 'row_number', 'file_name' are promoted to
        payload top-level for Qdrant filter support.
        """
        if not self._client:
            return
        name = self.collection_name(project_id)
        points = []
        for c in chunks:
            payload = {"text": c["text"], "source_id": c["source_id"]}
            if "metadata" in c:
                meta = c["metadata"]
                # Promote key fields to top level for Qdrant filtering
                for key in ("sheet", "row_number", "file_name"):
                    if key in meta:
                        payload[key] = meta[key]
                payload["metadata"] = meta
            points.append(PointStruct(id=c["id"], vector=c["vector"], payload=payload))
        # BUG FIX: skip the round-trip entirely for an empty batch — Qdrant
        # rejects an upsert with zero points.
        if not points:
            return
        self._client.upsert(collection_name=name, points=points)

    def search(self, project_id: str, query_vec: list[float], top_k: int = 5,
               file_ids: list[str] | None = None) -> list[dict]:
        """Return top_k nearest chunks, optionally restricted to file_ids.

        Each result dict has keys: text, source_id, score, metadata.
        """
        if not self._client:
            return []
        name = self.collection_name(project_id)
        query_filter = None
        if file_ids:
            # Restrict hits to the given source documents.
            query_filter = Filter(must=[
                FieldCondition(key="source_id", match=MatchAny(any=file_ids))
            ])
        results = self._client.query_points(
            collection_name=name,
            query=query_vec,
            limit=top_k,
            query_filter=query_filter,
            with_payload=True,
        ).points
        return [
            {
                "text": r.payload.get("text", ""),
                "source_id": r.payload.get("source_id", ""),
                "score": r.score,
                "metadata": r.payload.get("metadata", {}),
            }
            for r in results
        ]

    def delete_by_source(self, project_id: str, source_id: str):
        """Delete every point whose payload source_id matches."""
        if not self._client:
            return
        name = self.collection_name(project_id)
        self._client.delete(
            collection_name=name,
            points_selector=Filter(must=[
                FieldCondition(key="source_id", match=MatchAny(any=[source_id]))
            ]),
        )

    def test_connection(self) -> bool:
        """Lightweight liveness probe; never raises."""
        if not self._client:
            return False
        try:
            self._client.get_collections()
            return True
        except Exception:
            # Broad catch is deliberate: any failure means "not reachable".
            return False
|
||
|
||
|
||
# ────────── RAG Service ──────────
|
||
|
||
class RAGService:
    """Orchestrates chunking, embedding, and vector indexing/search."""

    def __init__(self, embedding: EmbeddingService, store: VectorStore):
        self.embedding = embedding
        self.store = store

    async def index_document(self, project_id: str, source_id: str, content: str,
                             emb_config: dict):
        """Chunk, embed, and index a document."""
        text_chunks = chunk_text(content, 500, 50)
        dim = emb_config.get("dimensions", 1024)
        await self.index_chunks(project_id, source_id, text_chunks, emb_config, dim=dim)

    async def index_chunks(self, project_id: str, source_id: str,
                           text_chunks: list[str], emb_config: dict,
                           dim: int = 1024,
                           metadata_list: list[dict] | None = None,
                           on_progress: Callable[[int, int], None] | None = None):
        """Embed and index pre-chunked text with optional per-chunk metadata.

        Args:
            metadata_list: optional per-chunk metadata, aligned with
                text_chunks by index.
            on_progress: invoked as on_progress(done, total) every 10 chunks.
        """
        self.store.ensure_collection(project_id, dim=dim)
        total = len(text_chunks)
        logger.info("index_chunks: %d chunks for source %s (dim=%d)", total, source_id, dim)

        batch: list[dict] = []
        batch_size = 20  # Insert in batches to reduce Qdrant round-trips
        inserted = 0     # actual count — blank chunks / failed embeddings are skipped

        for i, text in enumerate(text_chunks):
            if not text or not text.strip():
                continue
            vec = await self.embedding.get_embedding(
                text, emb_config["base_url"], emb_config["model"],
                emb_config["api_key"], emb_config["provider"],
            )
            if not vec:
                logger.warning("Empty vector for chunk %d, skipping", i)
                continue
            # Deterministic id so re-indexing the same source overwrites in place.
            chunk_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{source_id}-chunk-{i}"))
            entry: dict = {"id": chunk_id, "source_id": source_id, "text": text, "vector": vec}
            if metadata_list and i < len(metadata_list):
                entry["metadata"] = metadata_list[i]
            batch.append(entry)
            inserted += 1

            # Flush batch
            if len(batch) >= batch_size:
                self.store.insert(project_id, batch)
                batch = []

            # Report progress
            if on_progress and (i + 1) % 10 == 0:
                on_progress(i + 1, total)

        # Final batch
        if batch:
            self.store.insert(project_id, batch)

        # BUG FIX: previously logged `total` as the inserted count even though
        # blank chunks and embedding failures are skipped above.
        logger.info("Inserted vectors for source %s (%d/%d chunks)", source_id, inserted, total)

    async def search_context(self, project_id: str, question: str, top_k: int,
                             emb_config: dict, file_ids: list[str] | None = None) -> list[dict]:
        """Embed *question* and return the top_k matching chunks."""
        query_vec = await self.embedding.get_embedding(
            question, emb_config["base_url"], emb_config["model"],
            emb_config["api_key"], emb_config["provider"],
        )
        return self.store.search(project_id, query_vec, top_k, file_ids)
|
||
|
||
|
||
# ────────── SQLite keyword search (scored ranking) ──────────
|
||
|
||
def search_text_chunks_keyword(session, project_id: str, query: str,
                               file_ids: list[str] | None, top_k: int) -> list[dict]:
    """Keyword search with scored ranking.

    1. Extract keywords from query (Chinese-aware n-gram splitting)
    2. Fetch chunks matching ANY keyword (OR, wide net)
    3. Score each chunk by how many keywords it contains
    4. Return top_k sorted by score descending

    Args:
        session: SQLAlchemy session bound to the project database.
        file_ids: optional whitelist of source documents.

    Returns:
        List of {'text', 'source_id'} dicts, best matches first.
    """
    from models import TextChunk

    keywords = _extract_keywords(query)
    if not keywords:
        return []

    q = session.query(TextChunk).filter(TextChunk.project_id == project_id)
    if file_ids:
        q = q.filter(TextChunk.source_id.in_(file_ids))

    from sqlalchemy import or_
    # One ILIKE condition per keyword; keywords is non-empty here, so the
    # former `if not conditions` guard was dead code and has been removed.
    conditions = [TextChunk.content.ilike(f"%{kw}%") for kw in keywords]

    # Fetch wider pool, then rank by keyword hit count
    fetch_limit = max(top_k * 4, 20)
    candidates = q.filter(or_(*conditions)).limit(fetch_limit).all()

    # Score each chunk: count how many keywords appear in its content
    scored = []
    for c in candidates:
        text_lower = c.content.lower()
        hits = sum(1 for kw in keywords if kw.lower() in text_lower)
        scored.append((hits, c))

    # Sort by hit count descending, take top_k
    scored.sort(key=lambda x: x[0], reverse=True)
    return [
        {"text": c.content, "source_id": c.source_id}
        for _, c in scored[:top_k]
    ]
|
||
|
||
|
||
def _extract_keywords(query: str) -> list[str]:
    """Extract search keywords from a (mostly Chinese) query, no jieba needed.

    Strategy:
    1. Remove stop words / particles
    2. Split on punctuation and whitespace
    3. Keep short segments whole; for longer CJK segments also generate
       overlapping 2-4 char n-grams
    4. Deduplicate, preserving first-seen order
    """
    import re

    # Common stop words / particles
    stop_chars = set("的了是在和与对有不这那我你它们都吗呢吧啊哦呀嘛")
    stop_words = {"多少", "什么", "怎么", "如何", "哪些", "哪个", "请问",
                  "告诉", "可以", "一下", "一共", "总共", "分别"}

    # Remove stop characters (note: removed wherever they appear, even mid-word)
    cleaned = "".join(c for c in query if c not in stop_chars)

    # Split on punctuation, whitespace, and non-CJK characters
    segments = re.split(r'[,。、?!,.?!\s\n\t::;\-—()()\[\]【】{}""\']+', cleaned)

    keywords: list[str] = []
    seen: set = set()

    def _add(kw: str):
        # Keep only 2+ char, unseen, non-stop-word candidates.
        if kw and len(kw) >= 2 and kw not in seen and kw not in stop_words:
            seen.add(kw)
            keywords.append(kw)

    # CJK Unified Ideographs — n-grams only make sense for unsegmented CJK text.
    has_cjk = re.compile(r"[\u4e00-\u9fff]")

    for seg in segments:
        seg = seg.strip()
        if not seg:
            continue

        # Always keep the full segment for exact matching.
        _add(seg)

        # BUG FIX: only CJK segments get n-grammed. Previously any long
        # segment was shredded, so English words like "budget" produced
        # meaningless fragments ("budg", "udge", ...) that polluted matching.
        if len(seg) <= 4 or not has_cjk.search(seg):
            continue

        # Generate overlapping n-grams (2, 3, 4 chars) for longer CJK runs.
        for n in (4, 3, 2):
            for i in range(len(seg) - n + 1):
                _add(seg[i:i + n])

    return keywords
|
||
|
||
|
||
# ── Singletons ──
|
||
|
||
# Module-level singletons shared by the application. RAGService wires the
# embedding client to the vector store; importers use these instances rather
# than constructing their own.
embedding_service = EmbeddingService()
vector_store = VectorStore()
rag_service = RAGService(embedding_service, vector_store)
|