diff --git a/core/database.py b/core/database.py index 8f21f53..293a303 100644 --- a/core/database.py +++ b/core/database.py @@ -1,7 +1,9 @@ import os import logging +import sqlite3 from datetime import datetime -from sqlalchemy import create_engine, Column, String, Text, Boolean, DateTime, Integer, ForeignKey, JSON, Index, func, text +from sqlalchemy import event, create_engine, Column, String, Text, Boolean, DateTime, Integer, ForeignKey, JSON, Index, func, text +from sqlalchemy.engine import Engine from sqlalchemy.types import TypeDecorator from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.orm import relationship, sessionmaker, backref @@ -34,6 +36,18 @@ engine = create_engine( SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +# Listening on the Engine class ensures this listener fires for all Engine +# instances created within the process, not just the primary application engine. +# The isinstance(sqlite3.Connection) check ensures that this PRAGMA foreign_keys=ON +# configuration remains a no-op when using non-SQLite database backends. +@event.listens_for(Engine, "connect") +def set_sqlite_pragma(dbapi_connection, connection_record): + if isinstance(dbapi_connection, sqlite3.Connection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + class EncryptedText(TypeDecorator): """Text column transparently encrypted at rest via src.secret_storage. @@ -1484,6 +1498,10 @@ def _migrate_seed_email_account(): logging.getLogger(__name__).warning(f"seed email account migration: {e}") +# WARNING: Foreign-key enforcement is enabled globally for all SQLite connections. +# Any future migrations or schema changes that temporarily violate foreign-key +# constraints will fail. To perform such operations, foreign_keys must be +# temporarily disabled around the migration workflow. def init_db(): """ Initialize the database by creating all tables. diff --git a/tests/test_sqlite_foreign_keys.py b/tests/test_sqlite_foreign_keys.py new file mode 100644 index 0000000..c3df88c --- /dev/null +++ b/tests/test_sqlite_foreign_keys.py @@ -0,0 +1,38 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from core.database import Base, Session, ChatMessage +from datetime import datetime + +def test_sqlite_foreign_keys_cascade(): + engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False}) + Base.metadata.create_all(bind=engine) + + TestSessionLocal = sessionmaker(bind=engine) + db = TestSessionLocal() + + session_id = "test-session-123" + s = Session( + id=session_id, + name="Test Session", + endpoint_url="http://localhost:8000", + model="gpt-4", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + m = ChatMessage(id="test-msg-123", session_id=session_id, role="user", content="test message") + + db.add(s) + db.add(m) + db.commit() + + assert db.query(Session).count() == 1 + assert db.query(ChatMessage).count() == 1 + + db.query(Session).filter(Session.id == session_id).delete() + db.commit() + + assert db.query(ChatMessage).count() == 0 + + db.close() +