111 lines
4.0 KiB
Python
111 lines
4.0 KiB
Python
"""Database manager: global DB + per-project isolated DBs."""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker, Session
|
|
|
|
from models import (
|
|
Base, LLMProvider, VectorDBConfig, EmbeddingModelConfig, Project,
|
|
SourceFile, ChatMessage, TemplateChapter, TextChunk, DeliveryStandard,
|
|
)
|
|
|
|
# Tables stored in the global database (shared across all projects).
GLOBAL_TABLES = {
    model.__table__
    for model in (LLMProvider, VectorDBConfig, EmbeddingModelConfig, Project)
}

# Tables stored in each project's own isolated database.
PROJECT_TABLES = {
    model.__table__
    for model in (
        SourceFile,
        ChatMessage,
        TemplateChapter,
        TextChunk,
        DeliveryStandard,
    )
}
|
|
|
|
|
|
def _data_dir() -> Path:
    """Return the application data directory ``~/.engimind``.

    The directory (and any missing parents) is created on every call if it
    does not already exist, so callers can write into it immediately.
    """
    data_dir = Path.home().joinpath(".engimind")
    data_dir.mkdir(parents=True, exist_ok=True)
    return data_dir
|
|
|
|
|
|
class DatabaseManager:
    """Manages global + project SQLite databases.

    The global database (``global.db`` in the data directory) holds the
    global tables. At most one project database is open at a time; each
    call to ``open_project`` closes the previous one first.
    """

    def __init__(self):
        # Engines and session factories are created by init_global() and
        # open_project(); until then they are None.
        self._global_engine = None
        self._global_session_factory = None
        self._project_engine = None
        self._project_session_factory = None
        # ID of the currently open project, or "" when none is open.
        self.current_project_id: str = ""

    # ── Global DB ──

    def init_global(self):
        """Open (creating if needed) the global database, migrate and seed it.

        May be called more than once: a previously opened global engine is
        disposed first so its connection pool is not leaked.
        """
        if self._global_engine is not None:
            self._global_engine.dispose()
            self._global_engine = None
            self._global_session_factory = None
        db_path = _data_dir() / "global.db"
        self._global_engine = create_engine(f"sqlite:///{db_path}", echo=False)
        # Create only global tables; project tables live in per-project DBs.
        Base.metadata.create_all(self._global_engine, tables=list(GLOBAL_TABLES))
        self._global_session_factory = sessionmaker(bind=self._global_engine)
        self._migrate_global()
        self._seed_defaults()

    def global_session(self) -> Session:
        """Return a new session bound to the global database.

        Raises:
            RuntimeError: if init_global() has not been called yet.
        """
        if self._global_session_factory is None:
            raise RuntimeError(
                "Global database is not initialized; call init_global() first"
            )
        return self._global_session_factory()

    def _migrate_global(self):
        """Add missing columns to existing tables (SQLite ALTER TABLE).

        create_all() never alters existing tables, so columns added to the
        models after a database was first created must be added here.
        """
        from sqlalchemy import text, inspect

        insp = inspect(self._global_engine)
        cols = {c["name"] for c in insp.get_columns("embedding_model_configs")}
        if "dimensions" not in cols:
            with self._global_engine.begin() as conn:
                conn.execute(text(
                    "ALTER TABLE embedding_model_configs ADD COLUMN dimensions INTEGER DEFAULT 1024"
                ))

    def _seed_defaults(self):
        """Insert default LLM providers and vector-DB config into empty tables."""
        with self.global_session() as s:
            if s.query(LLMProvider).count() == 0:
                s.add_all([
                    LLMProvider(id="cfg1", name="DeepSeek Cloud", provider="DeepSeek",
                                base_url="https://api.deepseek.com", api_key="",
                                model_id="deepseek-reasoner", enabled=True),
                    LLMProvider(id="cfg2", name="Ollama: Local", provider="Ollama",
                                base_url="http://localhost:11434", api_key="",
                                model_id="qwen2.5:32b", enabled=True),
                ])
                s.commit()
            if s.query(VectorDBConfig).count() == 0:
                s.add(VectorDBConfig(id=1, endpoint="http://localhost:6333",
                                     api_key="", status="disconnected"))
                s.commit()

    # ── Project DB ──

    def open_project(self, project_id: str, db_path: str):
        """Open (creating if needed) the isolated database for *project_id*.

        Any currently open project is closed first. If the project schema
        cannot be created at *db_path*, the new engine is disposed and the
        exception is re-raised, leaving no project open.
        """
        self.close_project()
        engine = create_engine(f"sqlite:///{db_path}", echo=False)
        try:
            Base.metadata.create_all(engine, tables=list(PROJECT_TABLES))
        except Exception:
            # Don't leave a half-opened engine on the instance.
            engine.dispose()
            raise
        self._project_engine = engine
        self._project_session_factory = sessionmaker(bind=engine)
        self.current_project_id = project_id

    def close_project(self):
        """Dispose the project engine (if any) and clear the current project."""
        if self._project_engine is not None:
            self._project_engine.dispose()
        self._project_engine = None
        self._project_session_factory = None
        self.current_project_id = ""

    def project_session(self) -> Session | None:
        """Return a new session for the open project, or None if none is open."""
        if self._project_session_factory is None:
            return None
        return self._project_session_factory()

    def projects_base_dir(self) -> Path:
        """Return ``<data dir>/projects``, creating it if needed."""
        d = _data_dir() / "projects"
        d.mkdir(parents=True, exist_ok=True)
        return d
|
|
|
|
|
|
# Shared module-level manager instance, created at import time.
# init_global() must be called on it before global_session() is used.
db: DatabaseManager = DatabaseManager()
|