"""Embedding service + Qdrant vector store + RAG orchestration.""" from __future__ import annotations import hashlib import logging import uuid from typing import Optional import httpx from qdrant_client import QdrantClient from qdrant_client.models import ( Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchAny, ) logger = logging.getLogger("engimind.vector") # ────────── Embedding ────────── class EmbeddingService: def __init__(self): self._client = httpx.AsyncClient(timeout=60.0) async def get_embedding(self, text: str, base_url: str, model: str, api_key: str, provider: str) -> list[float]: if provider == "Ollama": return await self._ollama(text, base_url, model) return await self._openai(text, base_url, model, api_key) async def _ollama(self, text: str, base_url: str, model: str) -> list[float]: resp = await self._client.post(f"{base_url}/api/embeddings", json={"model": model, "prompt": text}) resp.raise_for_status() data = resp.json() return data.get("embedding", []) async def _openai(self, text: str, base_url: str, model: str, api_key: str) -> list[float]: url = base_url.rstrip("/") url = url + "/embeddings" if url.endswith("/v1") else url + "/v1/embeddings" headers = {"Content-Type": "application/json"} if api_key: headers["Authorization"] = f"Bearer {api_key}" # Dashscope requires non-empty input; truncate to stay within token limits if not text or not text.strip(): logger.warning("Skipping empty text for embedding") return [] # Safety: truncate to ~6000 chars (~8192 tokens for Chinese) if len(text) > 6000: logger.warning("Truncating embedding input from %d to 6000 chars", len(text)) text = text[:6000] resp = await self._client.post(url, json={"model": model, "input": text}, headers=headers) if resp.status_code != 200: body = resp.text[:500] logger.error("Embedding API %d: %s (text[:80]=%r)", resp.status_code, body, text[:80]) resp.raise_for_status() data = resp.json() if data.get("data") and data["data"][0].get("embedding"): return data["data"][0]["embedding"] return [] async def close(self): await self._client.aclose() # ────────── Text Chunking ────────── def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]: runes = list(text) if len(runes) <= chunk_size: return [text] chunks = [] start = 0 while start < len(runes): end = min(start + chunk_size, len(runes)) chunks.append("".join(runes[start:end])) start += chunk_size - overlap return chunks # ────────── Qdrant Store ────────── class VectorStore: def __init__(self): self._client: Optional[QdrantClient] = None def connect(self, endpoint: str): """Connect to Qdrant. Endpoint like 'http://localhost:6333'.""" self.disconnect() self._client = QdrantClient(url=endpoint, timeout=10) logger.info("Connected to Qdrant at %s", endpoint) def disconnect(self): if self._client: self._client.close() self._client = None @property def connected(self) -> bool: return self._client is not None def collection_name(self, project_id: str) -> str: return f"engimind_{project_id}" def ensure_collection(self, project_id: str, dim: int = 1024): if not self._client: return name = self.collection_name(project_id) if self._client.collection_exists(name): info = self._client.get_collection(name) existing_dim = info.config.params.vectors.size if existing_dim != dim: logger.warning("Collection %s dim mismatch (%d vs %d), recreating", name, existing_dim, dim) self._client.delete_collection(name) else: return self._client.create_collection( collection_name=name, vectors_config=VectorParams(size=dim, distance=Distance.COSINE), ) logger.info("Created collection %s (dim=%d)", name, dim) def insert(self, project_id: str, chunks: list[dict]): """Insert chunks: [{'id': str, 'source_id': str, 'text': str, 'vector': list, 'metadata': dict?}].""" if not self._client: return name = self.collection_name(project_id) points = [] for c in chunks: payload = {"text": c["text"], "source_id": c["source_id"]} if "metadata" in c: payload["metadata"] = c["metadata"] points.append(PointStruct(id=c["id"], vector=c["vector"], payload=payload)) self._client.upsert(collection_name=name, points=points) def search(self, project_id: str, query_vec: list[float], top_k: int = 5, file_ids: list[str] | None = None) -> list[dict]: if not self._client: return [] name = self.collection_name(project_id) query_filter = None if file_ids: query_filter = Filter(must=[ FieldCondition(key="source_id", match=MatchAny(any=file_ids)) ]) results = self._client.query_points( collection_name=name, query=query_vec, limit=top_k, query_filter=query_filter, with_payload=True, ).points return [ { "text": r.payload.get("text", ""), "source_id": r.payload.get("source_id", ""), "score": r.score, "metadata": r.payload.get("metadata", {}), } for r in results ] def delete_by_source(self, project_id: str, source_id: str): if not self._client: return name = self.collection_name(project_id) self._client.delete( collection_name=name, points_selector=Filter(must=[ FieldCondition(key="source_id", match=MatchAny(any=[source_id])) ]), ) def test_connection(self) -> bool: if not self._client: return False try: self._client.get_collections() return True except Exception: return False # ────────── RAG Service ────────── class RAGService: def __init__(self, embedding: EmbeddingService, store: VectorStore): self.embedding = embedding self.store = store async def index_document(self, project_id: str, source_id: str, content: str, emb_config: dict): """Chunk, embed, and index a document.""" text_chunks = chunk_text(content, 500, 50) dim = emb_config.get("dimensions", 1024) await self.index_chunks(project_id, source_id, text_chunks, emb_config, dim=dim) async def index_chunks(self, project_id: str, source_id: str, text_chunks: list[str], emb_config: dict, dim: int = 1024, metadata_list: list[dict] | None = None, on_progress: callable | None = None): """Embed and index pre-chunked text with optional per-chunk metadata.""" self.store.ensure_collection(project_id, dim=dim) total = len(text_chunks) logger.info("index_chunks: %d chunks for source %s (dim=%d)", total, source_id, dim) batch: list[dict] = [] batch_size = 20 # Insert in batches to reduce Qdrant round-trips for i, text in enumerate(text_chunks): if not text or not text.strip(): continue vec = await self.embedding.get_embedding( text, emb_config["base_url"], emb_config["model"], emb_config["api_key"], emb_config["provider"], ) if not vec: logger.warning("Empty vector for chunk %d, skipping", i) continue chunk_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{source_id}-chunk-{i}")) entry: dict = {"id": chunk_id, "source_id": source_id, "text": text, "vector": vec} if metadata_list and i < len(metadata_list): entry["metadata"] = metadata_list[i] batch.append(entry) # Flush batch if len(batch) >= batch_size: self.store.insert(project_id, batch) batch = [] # Report progress if on_progress and (i + 1) % 10 == 0: on_progress(i + 1, total) # Final batch if batch: self.store.insert(project_id, batch) inserted = total # approximate logger.info("Inserted vectors for source %s (%d chunks)", source_id, inserted) async def search_context(self, project_id: str, question: str, top_k: int, emb_config: dict, file_ids: list[str] | None = None) -> list[dict]: """Search for relevant chunks.""" query_vec = await self.embedding.get_embedding( question, emb_config["base_url"], emb_config["model"], emb_config["api_key"], emb_config["provider"], ) return self.store.search(project_id, query_vec, top_k, file_ids) # ────────── SQLite fallback keyword search ────────── def search_text_chunks_keyword(session, project_id: str, query: str, file_ids: list[str] | None, top_k: int) -> list[dict]: """Keyword-based fallback when vector search is unavailable.""" from models import TextChunk keywords = _extract_keywords(query) if not keywords: return [] q = session.query(TextChunk).filter(TextChunk.project_id == project_id) if file_ids: q = q.filter(TextChunk.source_id.in_(file_ids)) from sqlalchemy import or_ conditions = [TextChunk.content.ilike(f"%{kw}%") for kw in keywords if len(kw) >= 2] if not conditions: return [] q = q.filter(or_(*conditions)).order_by(TextChunk.chunk_idx).limit(top_k) return [{"text": c.content, "source_id": c.source_id} for c in q.all()] def _extract_keywords(query: str) -> list[str]: import re stop = {"的", "了", "是", "在", "和", "与", "对", "有", "不", "这", "那", "我", "你"} parts = re.split(r'[,。、?!,.\?!\s\n\t]+', query) return [p.strip() for p in parts if len(p.strip()) >= 2 and p.strip() not in stop] # ── Singletons ── embedding_service = EmbeddingService() vector_store = VectorStore() rag_service = RAGService(embedding_service, vector_store)