Odysseus v1.0
This commit is contained in:
558
core/session_manager.py
Normal file
558
core/session_manager.py
Normal file
@@ -0,0 +1,558 @@
|
||||
# core/session_manager.py
|
||||
"""
|
||||
Session management — all session business logic and DB operations.
|
||||
|
||||
This is the single place that handles:
|
||||
- Loading/saving sessions to database
|
||||
- Adding messages to sessions
|
||||
- Session lifecycle (create, archive, delete)
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .database import Session as DbSession, ChatMessage as DbChatMessage, Document as DbDocument, SessionLocal
|
||||
from .models import Session, ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages chat sessions with database persistence.
|
||||
|
||||
Usage:
|
||||
manager = SessionManager()
|
||||
session = manager.create_session(id, name, url, model)
|
||||
manager.add_message(session.id, ChatMessage("user", "hello"))
|
||||
session = manager.get_session(session_id)
|
||||
"""
|
||||
|
||||
def __init__(self, sessions_file: str = None):
|
||||
# sessions_file kept for backward compat, not used
|
||||
self.sessions: Dict[str, Session] = {}
|
||||
self.load_sessions()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Loading
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_sessions(self):
|
||||
"""Load recent session METADATA from the database — messages are
|
||||
hydrated on demand by `get_session`. Previously this walked every
|
||||
message of every session into RAM at boot, which on a long-running
|
||||
personal-server box could be tens of thousands of rows held forever
|
||||
in `self.sessions`.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_sessions = db.query(DbSession).filter(
|
||||
DbSession.archived == False,
|
||||
DbSession.message_count > 0,
|
||||
).order_by(DbSession.last_accessed.desc()).limit(100).all()
|
||||
|
||||
loaded_count = 0
|
||||
for db_session in db_sessions:
|
||||
try:
|
||||
session = self._db_to_session_meta(db_session)
|
||||
if session is not None:
|
||||
self.sessions[db_session.id] = session
|
||||
loaded_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading session {db_session.id}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loaded {loaded_count} session(s) (metadata only)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading sessions: {e}")
|
||||
self.sessions = {}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _db_to_session_meta(self, db_session: DbSession) -> Optional[Session]:
|
||||
"""Build a Session with empty history. `get_session` will hydrate
|
||||
messages from the DB on first read."""
|
||||
headers = db_session.headers
|
||||
if isinstance(headers, str):
|
||||
try:
|
||||
headers = json.loads(headers)
|
||||
except json.JSONDecodeError:
|
||||
headers = {}
|
||||
session = Session(
|
||||
id=db_session.id,
|
||||
name=db_session.name,
|
||||
endpoint_url=db_session.endpoint_url,
|
||||
model=db_session.model,
|
||||
rag=db_session.rag,
|
||||
archived=db_session.archived,
|
||||
headers=headers,
|
||||
history=[],
|
||||
owner=getattr(db_session, "owner", None),
|
||||
is_important=getattr(db_session, "is_important", False) or False,
|
||||
)
|
||||
session.message_count = getattr(db_session, "message_count", 0) or 0
|
||||
return session
|
||||
|
||||
def _db_to_session(self, db_session: DbSession, db) -> Optional[Session]:
|
||||
"""Convert a database session to a Session object."""
|
||||
history = []
|
||||
|
||||
# Try relationship first, then direct query
|
||||
if db_session.messages:
|
||||
for db_msg in db_session.messages:
|
||||
meta = json.loads(db_msg.meta_data) if db_msg.meta_data else {}
|
||||
if meta is None: meta = {}
|
||||
meta['_db_id'] = db_msg.id
|
||||
history.append(ChatMessage(
|
||||
role=db_msg.role,
|
||||
content=db_msg.content,
|
||||
metadata=meta,
|
||||
))
|
||||
else:
|
||||
db_messages = db.query(DbChatMessage).filter(
|
||||
DbChatMessage.session_id == db_session.id
|
||||
).order_by(DbChatMessage.timestamp).all()
|
||||
|
||||
for db_msg in db_messages:
|
||||
meta = json.loads(db_msg.meta_data) if db_msg.meta_data else {}
|
||||
if meta is None: meta = {}
|
||||
meta['_db_id'] = db_msg.id
|
||||
history.append(ChatMessage(
|
||||
role=db_msg.role,
|
||||
content=db_msg.content,
|
||||
metadata=meta,
|
||||
))
|
||||
|
||||
if not history:
|
||||
return None
|
||||
|
||||
# Parse headers
|
||||
headers = db_session.headers
|
||||
if isinstance(headers, str):
|
||||
try:
|
||||
headers = json.loads(headers)
|
||||
except json.JSONDecodeError:
|
||||
headers = {}
|
||||
|
||||
session = Session(
|
||||
id=db_session.id,
|
||||
name=db_session.name,
|
||||
endpoint_url=db_session.endpoint_url,
|
||||
model=db_session.model,
|
||||
rag=db_session.rag,
|
||||
archived=db_session.archived,
|
||||
headers=headers,
|
||||
history=history,
|
||||
owner=getattr(db_session, 'owner', None),
|
||||
is_important=getattr(db_session, 'is_important', False) or False,
|
||||
)
|
||||
|
||||
session.message_count = getattr(db_session, 'message_count', len(history))
|
||||
return session
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def add_message(self, session_id: str, message: ChatMessage):
|
||||
"""
|
||||
Add a message to a session and persist to database.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
message: ChatMessage to add
|
||||
"""
|
||||
session = self.get_session(session_id)
|
||||
session.history.append(message)
|
||||
session.message_count = len(session.history)
|
||||
|
||||
self._persist_message(session_id, message)
|
||||
|
||||
def _persist_message(self, session_id: str, message: ChatMessage):
|
||||
"""Persist a single message to the database."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
msg_id = str(uuid.uuid4())
|
||||
db_message = DbChatMessage(
|
||||
id=msg_id,
|
||||
session_id=session_id,
|
||||
role=message.role,
|
||||
content=message.content,
|
||||
meta_data=json.dumps(message.metadata) if message.metadata else None
|
||||
)
|
||||
db.add(db_message)
|
||||
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.message_count = len(self.sessions.get(session_id, {}).history) if session_id in self.sessions else 0
|
||||
_now = datetime.now(timezone.utc)
|
||||
db_session.last_accessed = _now
|
||||
# Clean "last conversation" timestamp — only bumped here on a
|
||||
# real message persist, so it powers an accurate "Last active"
|
||||
# sort that ignores renames / model swaps / mere opens.
|
||||
db_session.last_message_at = _now
|
||||
|
||||
db.commit()
|
||||
|
||||
# Store DB ID on the in-memory message for edit/delete by ID
|
||||
if message.metadata is None:
|
||||
message.metadata = {}
|
||||
message.metadata['_db_id'] = msg_id
|
||||
|
||||
logger.debug(f"Persisted message to session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error persisting message: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def truncate_messages(self, session_id: str, keep_count: int) -> bool:
|
||||
"""Truncate session history, keeping only the first `keep_count` messages."""
|
||||
session = self.get_session(session_id)
|
||||
|
||||
if keep_count < 0:
|
||||
return False
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_messages = db.query(DbChatMessage).filter(
|
||||
DbChatMessage.session_id == session_id
|
||||
).order_by(DbChatMessage.timestamp).all()
|
||||
|
||||
deleted = 0
|
||||
for msg in db_messages[keep_count:]:
|
||||
db.delete(msg)
|
||||
deleted += 1
|
||||
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.message_count = keep_count
|
||||
db_session.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Update in-memory
|
||||
session.history = session.history[:keep_count]
|
||||
|
||||
logger.info(f"Truncated session {session_id} to {keep_count} messages")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error truncating session: {e}")
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def replace_messages(self, session_id: str, messages: list) -> bool:
|
||||
"""Replace a session's persisted and in-memory history atomically."""
|
||||
session = self.get_session(session_id)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db.query(DbChatMessage).filter(DbChatMessage.session_id == session_id).delete()
|
||||
now = datetime.now(timezone.utc)
|
||||
for i, message in enumerate(messages):
|
||||
msg_id = str(uuid.uuid4())
|
||||
db_message = DbChatMessage(
|
||||
id=msg_id,
|
||||
session_id=session_id,
|
||||
role=message.role,
|
||||
content=message.content,
|
||||
meta_data=json.dumps(message.metadata) if message.metadata else None,
|
||||
timestamp=now + timedelta(microseconds=i),
|
||||
)
|
||||
db.add(db_message)
|
||||
if message.metadata is None:
|
||||
message.metadata = {}
|
||||
message.metadata["_db_id"] = msg_id
|
||||
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.message_count = len(messages)
|
||||
db_session.updated_at = now
|
||||
db_session.last_accessed = now
|
||||
db_session.last_message_at = now
|
||||
|
||||
db.commit()
|
||||
session.history = list(messages)
|
||||
session.message_count = len(messages)
|
||||
logger.info("Replaced session %s history with %d messages", session_id, len(messages))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error replacing session history: %s", e)
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_session(self, session_id: str) -> Session:
|
||||
"""Get a session by ID, loading from DB if needed.
|
||||
|
||||
Sessions seeded by `load_sessions` start with empty history. The
|
||||
first read here hydrates them with the message rows.
|
||||
"""
|
||||
if session_id not in self.sessions:
|
||||
self._load_session_from_db(session_id)
|
||||
else:
|
||||
cached = self.sessions[session_id]
|
||||
# Lazy hydrate: metadata-only entries get their messages on first read.
|
||||
if not cached.history and getattr(cached, "message_count", 0) > 0:
|
||||
self._load_session_from_db(session_id)
|
||||
|
||||
# Update last_accessed
|
||||
self._touch_session(session_id)
|
||||
|
||||
return self.sessions[session_id]
|
||||
|
||||
def _load_session_from_db(self, session_id: str):
|
||||
"""Hydrate a single session (with messages) from the database."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session is None:
|
||||
raise KeyError(f"Session {session_id} not found")
|
||||
|
||||
session = self._db_to_session(db_session, db)
|
||||
if session:
|
||||
self.sessions[session_id] = session
|
||||
else:
|
||||
# No messages — fall back to metadata-only entry so callers
|
||||
# don't crash on KeyError for empty sessions.
|
||||
meta = self._db_to_session_meta(db_session)
|
||||
if meta is None:
|
||||
raise KeyError(f"Session {session_id} could not be loaded")
|
||||
self.sessions[session_id] = meta
|
||||
|
||||
except KeyError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading session {session_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _touch_session(self, session_id: str):
|
||||
"""Update last_accessed timestamp."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.last_accessed = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating last_accessed: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
session_id: str,
|
||||
name: str,
|
||||
endpoint_url: str,
|
||||
model: str,
|
||||
rag: bool = False,
|
||||
owner: str = None
|
||||
) -> Session:
|
||||
"""Create a new session and save to database."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_session = DbSession(
|
||||
id=session_id,
|
||||
name=name,
|
||||
endpoint_url=endpoint_url,
|
||||
model=model,
|
||||
rag=rag,
|
||||
headers={},
|
||||
owner=owner,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.add(db_session)
|
||||
db.commit()
|
||||
|
||||
session = Session(
|
||||
id=session_id,
|
||||
name=name,
|
||||
endpoint_url=endpoint_url,
|
||||
model=model,
|
||||
rag=rag,
|
||||
headers={},
|
||||
owner=owner,
|
||||
)
|
||||
|
||||
self.sessions[session_id] = session
|
||||
return session
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error creating session: {e}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Permanently delete a session and all its messages."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Detach documents so they survive as orphans in the library
|
||||
db.query(DbDocument).filter(DbDocument.session_id == session_id).update(
|
||||
{DbDocument.session_id: None}, synchronize_session=False
|
||||
)
|
||||
|
||||
# Delete messages
|
||||
db.query(DbChatMessage).filter(DbChatMessage.session_id == session_id).delete()
|
||||
|
||||
# Delete session
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db.delete(db_session)
|
||||
db.commit()
|
||||
|
||||
if session_id in self.sessions:
|
||||
del self.sessions[session_id]
|
||||
|
||||
logger.info(f"Deleted session {session_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting session: {e}")
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session updates
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def update_session_name(self, session_id: str, name: str):
|
||||
"""Update session name."""
|
||||
if session_id not in self.sessions:
|
||||
return
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.name = name
|
||||
db_session.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
self.sessions[session_id].name = name
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating session name: {e}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def archive_session(self, session_id: str):
|
||||
"""Archive a session."""
|
||||
if session_id not in self.sessions:
|
||||
return
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.archived = True
|
||||
db_session.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
self.sessions[session_id].archived = True
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error archiving session: {e}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def mark_important(self, session_id: str, important: bool = True):
|
||||
"""Mark session as important."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.is_important = important
|
||||
db_session.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
if session_id in self.sessions:
|
||||
self.sessions[session_id].is_important = important
|
||||
else:
|
||||
raise KeyError(f"Session {session_id} not found")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error marking session important: {e}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Queries
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_sessions_for_user(self, username: Optional[str] = None) -> Dict[str, Session]:
|
||||
"""Return sessions for a specific user (or all if username is None)."""
|
||||
if username is None:
|
||||
return self.sessions
|
||||
return {
|
||||
sid: s for sid, s in self.sessions.items()
|
||||
if s.owner == username
|
||||
}
|
||||
|
||||
def save_sessions(self):
|
||||
"""No-op for DB compatibility."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def cleanup_empty_sessions(self, auto_archive_days: int = 30) -> dict:
|
||||
"""Clean up empty and old sessions."""
|
||||
db = SessionLocal()
|
||||
stats = {'deleted_empty': 0, 'archived_old': 0, 'total_checked': 0}
|
||||
|
||||
try:
|
||||
all_sessions = db.query(DbSession).all()
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=auto_archive_days)
|
||||
|
||||
for db_session in all_sessions:
|
||||
stats['total_checked'] += 1
|
||||
|
||||
# Delete empty sessions
|
||||
if db_session.message_count == 0:
|
||||
if db_session.id in self.sessions:
|
||||
del self.sessions[db_session.id]
|
||||
db.delete(db_session)
|
||||
stats['deleted_empty'] += 1
|
||||
|
||||
# Archive old sessions
|
||||
elif (not db_session.archived and
|
||||
db_session.last_accessed and
|
||||
db_session.last_accessed < cutoff_date and
|
||||
db_session.message_count > 0 and
|
||||
not getattr(db_session, 'is_important', False)):
|
||||
db_session.archived = True
|
||||
stats['archived_old'] += 1
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Cleanup: {stats['deleted_empty']} deleted, {stats['archived_old']} archived")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup error: {e}")
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return stats
|
||||
Reference in New Issue
Block a user