diff --git a/scripts/update_database.py b/scripts/update_database.py index 80f1489..195b0ba 100644 --- a/scripts/update_database.py +++ b/scripts/update_database.py @@ -166,116 +166,3 @@ def update_database(): if __name__ == "__main__": update_database() -""" -update_database.py - -This script updates the database schema by adding new columns to the sessions table -if they don't already exist. It uses raw SQL ALTER TABLE statements to modify -the existing SQLite database. - -The following columns are added: -- last_accessed (DateTime): Set to created_at for existing records -- is_important (Boolean): Set to False for existing records -- message_count (Integer): Calculated from the number of messages in chat_messages table - -Usage: - python update_database.py -""" - -import os -from datetime import datetime -from sqlalchemy import create_engine, text -from database import DATABASE_URL, SessionLocal - -def update_database(): - """Update the database schema and populate new columns.""" - # Create engine from DATABASE_URL - engine = create_engine(DATABASE_URL) - - # Start a transaction - db = SessionLocal() - try: - # Add last_accessed column if it doesn't exist - try: - with engine.connect() as conn: - conn.execute(text("ALTER TABLE sessions ADD COLUMN last_accessed DATETIME")) - conn.commit() - print("Added last_accessed column to sessions table") - except Exception as e: - if "duplicate column name" in str(e).lower(): - print("last_accessed column already exists") - else: - print(f"Error adding last_accessed column: {e}") - - # Add is_important column if it doesn't exist - try: - with engine.connect() as conn: - conn.execute(text("ALTER TABLE sessions ADD COLUMN is_important BOOLEAN DEFAULT FALSE")) - conn.commit() - print("Added is_important column to sessions table") - except Exception as e: - if "duplicate column name" in str(e).lower(): - print("is_important column already exists") - else: - print(f"Error adding is_important column: {e}") - - # Add message_count column if it doesn't exist - try: - with engine.connect() as conn: - conn.execute(text("ALTER TABLE sessions ADD COLUMN message_count INTEGER DEFAULT 0")) - conn.commit() - print("Added message_count column to sessions table") - except Exception as e: - if "duplicate column name" in str(e).lower(): - print("message_count column already exists") - else: - print(f"Error adding message_count column: {e}") - - # Populate last_accessed with created_at for existing records where last_accessed is NULL - print("Populating last_accessed column...") - with engine.connect() as conn: - conn.execute(text(""" - UPDATE sessions - SET last_accessed = created_at - WHERE last_accessed IS NULL - """)) - conn.commit() - - # Populate is_important with FALSE for existing records where is_important is NULL - print("Populating is_important column...") - with engine.connect() as conn: - conn.execute(text(""" - UPDATE sessions - SET is_important = 0 - WHERE is_important IS NULL - """)) - conn.commit() - - # Calculate and populate message_count from chat_messages table - print("Calculating and populating message_count column...") - with engine.connect() as conn: - # First, set all message_count to 0 - conn.execute(text("UPDATE sessions SET message_count = 0")) - - # Then, count messages for each session and update - conn.execute(text(""" - UPDATE sessions - SET message_count = ( - SELECT COUNT(*) - FROM chat_messages - WHERE chat_messages.session_id = sessions.id - ) - """)) - conn.commit() - - print("Database update completed successfully!") - - except Exception as e: - print(f"Error updating database: {e}") - db.rollback() - raise - finally: - db.close() - -if __name__ == "__main__": - update_database() diff --git a/tests/test_update_database_script.py b/tests/test_update_database_script.py new file mode 100644 index 0000000..3a17f0b --- /dev/null +++ b/tests/test_update_database_script.py @@ -0,0 +1,8 @@ +from pathlib import Path + + +def test_update_database_has_single_main_guard(): + script = Path(__file__).resolve().parent.parent / "scripts" / "update_database.py" + text = script.read_text() + + assert text.count('if __name__ == "__main__":') == 1