refactor: excel parse
This commit is contained in:
@@ -0,0 +1,288 @@
|
||||
"""Embedding service + Qdrant vector store + RAG orchestration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchAny,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("engimind.vector")
|
||||
|
||||
|
||||
# ────────── Embedding ──────────
|
||||
|
||||
class EmbeddingService:
|
||||
def __init__(self):
|
||||
self._client = httpx.AsyncClient(timeout=60.0)
|
||||
|
||||
async def get_embedding(self, text: str, base_url: str, model: str,
|
||||
api_key: str, provider: str) -> list[float]:
|
||||
if provider == "Ollama":
|
||||
return await self._ollama(text, base_url, model)
|
||||
return await self._openai(text, base_url, model, api_key)
|
||||
|
||||
async def _ollama(self, text: str, base_url: str, model: str) -> list[float]:
|
||||
resp = await self._client.post(f"{base_url}/api/embeddings",
|
||||
json={"model": model, "prompt": text})
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data.get("embedding", [])
|
||||
|
||||
async def _openai(self, text: str, base_url: str, model: str, api_key: str) -> list[float]:
|
||||
url = base_url.rstrip("/")
|
||||
url = url + "/embeddings" if url.endswith("/v1") else url + "/v1/embeddings"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
# Dashscope requires non-empty input; truncate to stay within token limits
|
||||
if not text or not text.strip():
|
||||
logger.warning("Skipping empty text for embedding")
|
||||
return []
|
||||
# Safety: truncate to ~6000 chars (~8192 tokens for Chinese)
|
||||
if len(text) > 6000:
|
||||
logger.warning("Truncating embedding input from %d to 6000 chars", len(text))
|
||||
text = text[:6000]
|
||||
resp = await self._client.post(url, json={"model": model, "input": text}, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
body = resp.text[:500]
|
||||
logger.error("Embedding API %d: %s (text[:80]=%r)", resp.status_code, body, text[:80])
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if data.get("data") and data["data"][0].get("embedding"):
|
||||
return data["data"][0]["embedding"]
|
||||
return []
|
||||
|
||||
async def close(self):
|
||||
await self._client.aclose()
|
||||
|
||||
|
||||
# ────────── Text Chunking ──────────
|
||||
|
||||
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
|
||||
runes = list(text)
|
||||
if len(runes) <= chunk_size:
|
||||
return [text]
|
||||
chunks = []
|
||||
start = 0
|
||||
while start < len(runes):
|
||||
end = min(start + chunk_size, len(runes))
|
||||
chunks.append("".join(runes[start:end]))
|
||||
start += chunk_size - overlap
|
||||
return chunks
|
||||
|
||||
|
||||
# ────────── Qdrant Store ──────────
|
||||
|
||||
class VectorStore:
|
||||
def __init__(self):
|
||||
self._client: Optional[QdrantClient] = None
|
||||
|
||||
def connect(self, endpoint: str):
|
||||
"""Connect to Qdrant. Endpoint like 'http://localhost:6333'."""
|
||||
self.disconnect()
|
||||
self._client = QdrantClient(url=endpoint, timeout=10)
|
||||
logger.info("Connected to Qdrant at %s", endpoint)
|
||||
|
||||
def disconnect(self):
|
||||
if self._client:
|
||||
self._client.close()
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
return self._client is not None
|
||||
|
||||
def collection_name(self, project_id: str) -> str:
|
||||
return f"engimind_{project_id}"
|
||||
|
||||
def ensure_collection(self, project_id: str, dim: int = 1024):
|
||||
if not self._client:
|
||||
return
|
||||
name = self.collection_name(project_id)
|
||||
if self._client.collection_exists(name):
|
||||
info = self._client.get_collection(name)
|
||||
existing_dim = info.config.params.vectors.size
|
||||
if existing_dim != dim:
|
||||
logger.warning("Collection %s dim mismatch (%d vs %d), recreating",
|
||||
name, existing_dim, dim)
|
||||
self._client.delete_collection(name)
|
||||
else:
|
||||
return
|
||||
self._client.create_collection(
|
||||
collection_name=name,
|
||||
vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
|
||||
)
|
||||
logger.info("Created collection %s (dim=%d)", name, dim)
|
||||
|
||||
def insert(self, project_id: str, chunks: list[dict]):
|
||||
"""Insert chunks: [{'id': str, 'source_id': str, 'text': str, 'vector': list, 'metadata': dict?}]."""
|
||||
if not self._client:
|
||||
return
|
||||
name = self.collection_name(project_id)
|
||||
points = []
|
||||
for c in chunks:
|
||||
payload = {"text": c["text"], "source_id": c["source_id"]}
|
||||
if "metadata" in c:
|
||||
payload["metadata"] = c["metadata"]
|
||||
points.append(PointStruct(id=c["id"], vector=c["vector"], payload=payload))
|
||||
self._client.upsert(collection_name=name, points=points)
|
||||
|
||||
def search(self, project_id: str, query_vec: list[float], top_k: int = 5,
|
||||
file_ids: list[str] | None = None) -> list[dict]:
|
||||
if not self._client:
|
||||
return []
|
||||
name = self.collection_name(project_id)
|
||||
query_filter = None
|
||||
if file_ids:
|
||||
query_filter = Filter(must=[
|
||||
FieldCondition(key="source_id", match=MatchAny(any=file_ids))
|
||||
])
|
||||
results = self._client.query_points(
|
||||
collection_name=name,
|
||||
query=query_vec,
|
||||
limit=top_k,
|
||||
query_filter=query_filter,
|
||||
with_payload=True,
|
||||
).points
|
||||
return [
|
||||
{
|
||||
"text": r.payload.get("text", ""),
|
||||
"source_id": r.payload.get("source_id", ""),
|
||||
"score": r.score,
|
||||
"metadata": r.payload.get("metadata", {}),
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
def delete_by_source(self, project_id: str, source_id: str):
|
||||
if not self._client:
|
||||
return
|
||||
name = self.collection_name(project_id)
|
||||
self._client.delete(
|
||||
collection_name=name,
|
||||
points_selector=Filter(must=[
|
||||
FieldCondition(key="source_id", match=MatchAny(any=[source_id]))
|
||||
]),
|
||||
)
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
if not self._client:
|
||||
return False
|
||||
try:
|
||||
self._client.get_collections()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# ────────── RAG Service ──────────
|
||||
|
||||
class RAGService:
|
||||
def __init__(self, embedding: EmbeddingService, store: VectorStore):
|
||||
self.embedding = embedding
|
||||
self.store = store
|
||||
|
||||
async def index_document(self, project_id: str, source_id: str, content: str,
|
||||
emb_config: dict):
|
||||
"""Chunk, embed, and index a document."""
|
||||
text_chunks = chunk_text(content, 500, 50)
|
||||
dim = emb_config.get("dimensions", 1024)
|
||||
await self.index_chunks(project_id, source_id, text_chunks, emb_config, dim=dim)
|
||||
|
||||
async def index_chunks(self, project_id: str, source_id: str,
|
||||
text_chunks: list[str], emb_config: dict,
|
||||
dim: int = 1024,
|
||||
metadata_list: list[dict] | None = None,
|
||||
on_progress: callable | None = None):
|
||||
"""Embed and index pre-chunked text with optional per-chunk metadata."""
|
||||
self.store.ensure_collection(project_id, dim=dim)
|
||||
total = len(text_chunks)
|
||||
logger.info("index_chunks: %d chunks for source %s (dim=%d)", total, source_id, dim)
|
||||
|
||||
batch: list[dict] = []
|
||||
batch_size = 20 # Insert in batches to reduce Qdrant round-trips
|
||||
|
||||
for i, text in enumerate(text_chunks):
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
vec = await self.embedding.get_embedding(
|
||||
text, emb_config["base_url"], emb_config["model"],
|
||||
emb_config["api_key"], emb_config["provider"],
|
||||
)
|
||||
if not vec:
|
||||
logger.warning("Empty vector for chunk %d, skipping", i)
|
||||
continue
|
||||
chunk_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{source_id}-chunk-{i}"))
|
||||
entry: dict = {"id": chunk_id, "source_id": source_id, "text": text, "vector": vec}
|
||||
if metadata_list and i < len(metadata_list):
|
||||
entry["metadata"] = metadata_list[i]
|
||||
batch.append(entry)
|
||||
|
||||
# Flush batch
|
||||
if len(batch) >= batch_size:
|
||||
self.store.insert(project_id, batch)
|
||||
batch = []
|
||||
|
||||
# Report progress
|
||||
if on_progress and (i + 1) % 10 == 0:
|
||||
on_progress(i + 1, total)
|
||||
|
||||
# Final batch
|
||||
if batch:
|
||||
self.store.insert(project_id, batch)
|
||||
|
||||
inserted = total # approximate
|
||||
logger.info("Inserted vectors for source %s (%d chunks)", source_id, inserted)
|
||||
|
||||
async def search_context(self, project_id: str, question: str, top_k: int,
|
||||
emb_config: dict, file_ids: list[str] | None = None) -> list[dict]:
|
||||
"""Search for relevant chunks."""
|
||||
query_vec = await self.embedding.get_embedding(
|
||||
question, emb_config["base_url"], emb_config["model"],
|
||||
emb_config["api_key"], emb_config["provider"],
|
||||
)
|
||||
return self.store.search(project_id, query_vec, top_k, file_ids)
|
||||
|
||||
|
||||
# ────────── SQLite fallback keyword search ──────────
|
||||
|
||||
def search_text_chunks_keyword(session, project_id: str, query: str,
|
||||
file_ids: list[str] | None, top_k: int) -> list[dict]:
|
||||
"""Keyword-based fallback when vector search is unavailable."""
|
||||
from models import TextChunk
|
||||
|
||||
keywords = _extract_keywords(query)
|
||||
if not keywords:
|
||||
return []
|
||||
|
||||
q = session.query(TextChunk).filter(TextChunk.project_id == project_id)
|
||||
if file_ids:
|
||||
q = q.filter(TextChunk.source_id.in_(file_ids))
|
||||
|
||||
from sqlalchemy import or_
|
||||
conditions = [TextChunk.content.ilike(f"%{kw}%") for kw in keywords if len(kw) >= 2]
|
||||
if not conditions:
|
||||
return []
|
||||
q = q.filter(or_(*conditions)).order_by(TextChunk.chunk_idx).limit(top_k)
|
||||
|
||||
return [{"text": c.content, "source_id": c.source_id} for c in q.all()]
|
||||
|
||||
|
||||
def _extract_keywords(query: str) -> list[str]:
|
||||
import re
|
||||
stop = {"的", "了", "是", "在", "和", "与", "对", "有", "不", "这", "那", "我", "你"}
|
||||
parts = re.split(r'[,。、?!,.\?!\s\n\t]+', query)
|
||||
return [p.strip() for p in parts if len(p.strip()) >= 2 and p.strip() not in stop]
|
||||
|
||||
|
||||
# ── Singletons ──
|
||||
|
||||
embedding_service = EmbeddingService()
|
||||
vector_store = VectorStore()
|
||||
rag_service = RAGService(embedding_service, vector_store)
|
||||
Reference in New Issue
Block a user