Files
2026-04-16 10:01:11 +08:00

111 lines
4.0 KiB
Python

"""Database manager: global DB + per-project isolated DBs."""
from __future__ import annotations
import os
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from models import (
Base, LLMProvider, VectorDBConfig, EmbeddingModelConfig, Project,
SourceFile, ChatMessage, TemplateChapter, TextChunk, DeliveryStandard,
)
# Tables that belong to the shared global database.
GLOBAL_TABLES = {
    model.__table__
    for model in (LLMProvider, VectorDBConfig, EmbeddingModelConfig, Project)
}
# Tables that belong to each per-project database.
PROJECT_TABLES = {
    model.__table__
    for model in (
        SourceFile,
        ChatMessage,
        TemplateChapter,
        TextChunk,
        DeliveryStandard,
    )
}
def _data_dir() -> Path:
    """Return the ``.engimind`` directory under the user's home, creating it if absent."""
    data_root = Path.home().joinpath(".engimind")
    # parents/exist_ok make the call safe both on first run and on every later one.
    data_root.mkdir(parents=True, exist_ok=True)
    return data_root
class DatabaseManager:
    """Manage the global SQLite database plus one per-project SQLite database.

    The global database is stored as ``global.db`` inside the application data
    directory and holds ``GLOBAL_TABLES``. At most one project database
    (holding ``PROJECT_TABLES``) is open at a time. The constructor opens
    nothing: call :meth:`init_global` and :meth:`open_project` explicitly.
    """

    def __init__(self):
        self._global_engine = None
        self._global_session_factory = None
        self._project_engine = None
        self._project_session_factory = None
        # ID of the currently open project; empty string when none is open.
        self.current_project_id: str = ""

    # ── Global DB ──

    def init_global(self):
        """Open the global database, create its tables, migrate and seed it.

        An engine left over from an earlier call is disposed before being
        replaced so that its pooled connections are not leaked.
        """
        if self._global_engine is not None:
            self._global_engine.dispose()
        db_path = _data_dir() / "global.db"
        self._global_engine = create_engine(f"sqlite:///{db_path}", echo=False)
        # Create only global tables
        Base.metadata.create_all(self._global_engine, tables=list(GLOBAL_TABLES))
        self._global_session_factory = sessionmaker(bind=self._global_engine)
        self._migrate_global()
        self._seed_defaults()

    def global_session(self) -> Session:
        """Return a new session bound to the global database.

        Raises:
            RuntimeError: if :meth:`init_global` has not been called yet.
        """
        if self._global_session_factory is None:
            # Fail with a clear message instead of "'NoneType' is not callable".
            raise RuntimeError(
                "Global database is not initialised; call init_global() first."
            )
        return self._global_session_factory()

    def _migrate_global(self):
        """Add missing columns to existing tables (SQLite ALTER TABLE)."""
        from sqlalchemy import text, inspect

        insp = inspect(self._global_engine)
        cols = {c["name"] for c in insp.get_columns("embedding_model_configs")}
        # Databases created before the column existed need it added in place;
        # create_all() never alters a table that already exists.
        if "dimensions" not in cols:
            with self._global_engine.begin() as conn:
                conn.execute(text(
                    "ALTER TABLE embedding_model_configs ADD COLUMN dimensions INTEGER DEFAULT 1024"
                ))

    def _seed_defaults(self):
        """Insert default LLM providers and vector DB config into empty tables."""
        with self.global_session() as s:
            # Seed each table only when it has no rows, so existing
            # configuration is never overwritten.
            if s.query(LLMProvider).count() == 0:
                s.add_all([
                    LLMProvider(id="cfg1", name="DeepSeek Cloud", provider="DeepSeek",
                                base_url="https://api.deepseek.com", api_key="",
                                model_id="deepseek-reasoner", enabled=True),
                    LLMProvider(id="cfg2", name="Ollama: Local", provider="Ollama",
                                base_url="http://localhost:11434", api_key="",
                                model_id="qwen2.5:32b", enabled=True),
                ])
                s.commit()
            if s.query(VectorDBConfig).count() == 0:
                s.add(VectorDBConfig(id=1, endpoint="http://localhost:6333",
                                     api_key="", status="disconnected"))
                s.commit()

    # ── Project DB ──

    def open_project(self, project_id: str, db_path: str):
        """Close any open project, then open the project database at *db_path*.

        Manager state is updated only after the project tables were created
        successfully. If table creation fails, the new engine is disposed, the
        manager is left with no project open, and the exception propagates.
        """
        self.close_project()
        engine = create_engine(f"sqlite:///{db_path}", echo=False)
        try:
            Base.metadata.create_all(engine, tables=list(PROJECT_TABLES))
        except Exception:
            # Do not leave a half-open project (engine without session factory).
            engine.dispose()
            raise
        self._project_engine = engine
        self._project_session_factory = sessionmaker(bind=engine)
        self.current_project_id = project_id

    def close_project(self):
        """Dispose the project engine, if any, and reset all project state."""
        if self._project_engine is not None:
            self._project_engine.dispose()
        self._project_engine = None
        self._project_session_factory = None
        self.current_project_id = ""

    def project_session(self) -> Session | None:
        """Return a new session for the open project, or ``None`` if none is open."""
        if self._project_session_factory is None:
            return None
        return self._project_session_factory()

    def projects_base_dir(self) -> Path:
        """Return the ``projects`` directory inside the data dir, creating it if absent."""
        d = _data_dir() / "projects"
        d.mkdir(parents=True, exist_ok=True)
        return d
# Singleton: module-level DatabaseManager instance. Constructing it opens no
# database; init_global() / open_project() must be called explicitly.
db = DatabaseManager()