Remove duplicate update database body (#1584)
This commit is contained in:
@@ -166,116 +166,3 @@ def update_database():
|
||||
|
||||
if __name__ == "__main__":
|
||||
update_database()
|
||||
"""
|
||||
update_database.py
|
||||
|
||||
This script updates the database schema by adding new columns to the sessions table
|
||||
if they don't already exist. It uses raw SQL ALTER TABLE statements to modify
|
||||
the existing SQLite database.
|
||||
|
||||
The following columns are added:
|
||||
- last_accessed (DateTime): Set to created_at for existing records
|
||||
- is_important (Boolean): Set to False for existing records
|
||||
- message_count (Integer): Calculated from the number of messages in chat_messages table
|
||||
|
||||
Usage:
|
||||
python update_database.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from sqlalchemy import create_engine, text
|
||||
from database import DATABASE_URL, SessionLocal
|
||||
|
||||
def update_database():
|
||||
"""Update the database schema and populate new columns."""
|
||||
# Create engine from DATABASE_URL
|
||||
engine = create_engine(DATABASE_URL)
|
||||
|
||||
# Start a transaction
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Add last_accessed column if it doesn't exist
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("ALTER TABLE sessions ADD COLUMN last_accessed DATETIME"))
|
||||
conn.commit()
|
||||
print("Added last_accessed column to sessions table")
|
||||
except Exception as e:
|
||||
if "duplicate column name" in str(e).lower():
|
||||
print("last_accessed column already exists")
|
||||
else:
|
||||
print(f"Error adding last_accessed column: {e}")
|
||||
|
||||
# Add is_important column if it doesn't exist
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("ALTER TABLE sessions ADD COLUMN is_important BOOLEAN DEFAULT FALSE"))
|
||||
conn.commit()
|
||||
print("Added is_important column to sessions table")
|
||||
except Exception as e:
|
||||
if "duplicate column name" in str(e).lower():
|
||||
print("is_important column already exists")
|
||||
else:
|
||||
print(f"Error adding is_important column: {e}")
|
||||
|
||||
# Add message_count column if it doesn't exist
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("ALTER TABLE sessions ADD COLUMN message_count INTEGER DEFAULT 0"))
|
||||
conn.commit()
|
||||
print("Added message_count column to sessions table")
|
||||
except Exception as e:
|
||||
if "duplicate column name" in str(e).lower():
|
||||
print("message_count column already exists")
|
||||
else:
|
||||
print(f"Error adding message_count column: {e}")
|
||||
|
||||
# Populate last_accessed with created_at for existing records where last_accessed is NULL
|
||||
print("Populating last_accessed column...")
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("""
|
||||
UPDATE sessions
|
||||
SET last_accessed = created_at
|
||||
WHERE last_accessed IS NULL
|
||||
"""))
|
||||
conn.commit()
|
||||
|
||||
# Populate is_important with FALSE for existing records where is_important is NULL
|
||||
print("Populating is_important column...")
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("""
|
||||
UPDATE sessions
|
||||
SET is_important = 0
|
||||
WHERE is_important IS NULL
|
||||
"""))
|
||||
conn.commit()
|
||||
|
||||
# Calculate and populate message_count from chat_messages table
|
||||
print("Calculating and populating message_count column...")
|
||||
with engine.connect() as conn:
|
||||
# First, set all message_count to 0
|
||||
conn.execute(text("UPDATE sessions SET message_count = 0"))
|
||||
|
||||
# Then, count messages for each session and update
|
||||
conn.execute(text("""
|
||||
UPDATE sessions
|
||||
SET message_count = (
|
||||
SELECT COUNT(*)
|
||||
FROM chat_messages
|
||||
WHERE chat_messages.session_id = sessions.id
|
||||
)
|
||||
"""))
|
||||
conn.commit()
|
||||
|
||||
print("Database update completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error updating database: {e}")
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
update_database()
|
||||
|
||||
8
tests/test_update_database_script.py
Normal file
8
tests/test_update_database_script.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_update_database_has_single_main_guard():
|
||||
script = Path(__file__).resolve().parent.parent / "scripts" / "update_database.py"
|
||||
text = script.read_text()
|
||||
|
||||
assert text.count('if __name__ == "__main__":') == 1
|
||||
Reference in New Issue
Block a user