fix(history): block compact during active runs (#2635)

This commit is contained in:
Ocean Bennett
2026-06-04 15:50:16 -04:00
committed by GitHub
parent 67782e684e
commit e69298888b
3 changed files with 135 additions and 7 deletions

View File

@@ -10,7 +10,12 @@ from fastapi import APIRouter, Request, HTTPException
from core.models import ChatMessage from core.models import ChatMessage
from core.database import SessionLocal, ChatMessage as DbChatMessage, Session as DbSession from core.database import SessionLocal, ChatMessage as DbChatMessage, Session as DbSession
from src.topic_analyzer import analyze_topics from src.topic_analyzer import analyze_topics
from routes.session_routes import _verify_session_owner from routes.session_routes import (
_message_role,
_message_text,
_reject_compact_during_active_run,
_verify_session_owner,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -521,6 +526,7 @@ def setup_history_routes(session_manager) -> APIRouter:
session = session_manager.get_session(session_id) session = session_manager.get_session(session_id)
except KeyError: except KeyError:
raise HTTPException(404, "Session not found") raise HTTPException(404, "Session not found")
_reject_compact_during_active_run(session_id)
try: try:
from src.model_context import estimate_tokens, get_context_length from src.model_context import estimate_tokens, get_context_length
@@ -543,8 +549,8 @@ def setup_history_routes(session_manager) -> APIRouter:
# Build text to summarize # Build text to summarize
convo_text = "\n".join( convo_text = "\n".join(
f"{(m.role if isinstance(m, ChatMessage) else m.get('role', '')).upper()}: " f"{_message_role(m).upper()}: "
f"{((m.content if isinstance(m, ChatMessage) else m.get('content')) or '')[:2000]}" f"{_message_text(m)[:2000]}"
for m in older for m in older
) )

View File

@@ -57,6 +57,40 @@ def _content_to_text(content) -> str:
return "" return ""
def _message_role(message) -> str:
if isinstance(message, ChatMessage):
return message.role or ""
if isinstance(message, dict):
return message.get("role", "") or ""
return getattr(message, "role", "") or ""
def _message_text(message) -> str:
if isinstance(message, ChatMessage):
content = message.content
elif isinstance(message, dict):
content = message.get("content")
else:
content = getattr(message, "content", None)
return _content_to_text(content)
def _message_metadata(message) -> dict:
if isinstance(message, ChatMessage):
metadata = message.metadata
elif isinstance(message, dict):
metadata = message.get("metadata")
else:
metadata = getattr(message, "metadata", None)
return metadata if isinstance(metadata, dict) else {}
def _reject_compact_during_active_run(session_id: str) -> None:
from src import agent_runs
if agent_runs.is_active(session_id):
raise HTTPException(409, "Session has an active run; try compacting after it finishes")
def _verify_session_owner(request: Request, session_id: str, session_manager=None): def _verify_session_owner(request: Request, session_id: str, session_manager=None):
"""Verify the current user owns the session. Raises 404 if not. """Verify the current user owns the session. Raises 404 if not.
@@ -872,6 +906,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
session = session_manager.get_session(session_id) session = session_manager.get_session(session_id)
except KeyError: except KeyError:
raise HTTPException(404, f"Session {session_id} not found") raise HTTPException(404, f"Session {session_id} not found")
_reject_compact_during_active_run(session_id)
history = list(session.history or []) history = list(session.history or [])
if len(history) < 6: if len(history) < 6:
@@ -897,7 +932,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
prior_compactions = sum( prior_compactions = sum(
1 for m in history 1 for m in history
if (m.metadata or {}).get("compacted") or "[Conversation summary" in (m.content or "") if _message_metadata(m).get("compacted") or "[Conversation summary" in _message_text(m)
) )
prompt = SELF_SUMMARY_SYSTEM_PROMPT.replace( prompt = SELF_SUMMARY_SYSTEM_PROMPT.replace(
"{count}", str(len(older)) "{count}", str(len(older))
@@ -905,7 +940,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
"{n}", str(prior_compactions + 1) "{n}", str(prior_compactions + 1)
) )
convo_text = "\n".join( convo_text = "\n".join(
f"{m.role.upper()}: {(m.content or '')[:2000]}" f"{_message_role(m).upper()}: {_message_text(m)[:2000]}"
for m in older for m in older
) )
try: try:

View File

@@ -1,10 +1,11 @@
from types import SimpleNamespace from types import SimpleNamespace
from fastapi import FastAPI from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from core.models import ChatMessage from core.models import ChatMessage
import routes.history_routes as history_routes import routes.history_routes as history_routes
import routes.session_routes as session_routes
class _FakeQuery: class _FakeQuery:
@@ -53,6 +54,7 @@ class _FakeSessionManager:
def __init__(self, session): def __init__(self, session):
self.session = session self.session = session
self.saved = False self.saved = False
self.replaced_messages = None
def get_session(self, session_id): def get_session(self, session_id):
if session_id != self.session.id: if session_id != self.session.id:
@@ -62,6 +64,14 @@ class _FakeSessionManager:
def save_sessions(self): def save_sessions(self):
self.saved = True self.saved = True
def replace_messages(self, session_id, messages):
if session_id != self.session.id:
return False
self.replaced_messages = list(messages)
self.session.history = list(messages)
self.session.message_count = len(messages)
return True
class _FakeSession: class _FakeSession:
id = "session-1" id = "session-1"
@@ -91,11 +101,13 @@ def _compact_prompt_for(monkeypatch, history):
monkeypatch.setattr(history_routes, "_verify_session_owner", lambda request, session_id: None) monkeypatch.setattr(history_routes, "_verify_session_owner", lambda request, session_id: None)
monkeypatch.setattr(history_routes, "SessionLocal", lambda: _FakeDb()) monkeypatch.setattr(history_routes, "SessionLocal", lambda: _FakeDb())
import src.agent_runs as agent_runs
import src.endpoint_resolver as endpoint_resolver import src.endpoint_resolver as endpoint_resolver
import src.llm_core as llm_core import src.llm_core as llm_core
import src.model_context as model_context import src.model_context as model_context
monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", lambda kind: (None, None, {})) monkeypatch.setattr(agent_runs, "is_active", lambda session_id: False)
monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", lambda kind, owner=None: (None, None, {}))
monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call_async) monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call_async)
monkeypatch.setattr(model_context, "estimate_tokens", lambda messages: 100) monkeypatch.setattr(model_context, "estimate_tokens", lambda messages: 100)
monkeypatch.setattr(model_context, "get_context_length", lambda endpoint_url, model: 1000) monkeypatch.setattr(model_context, "get_context_length", lambda endpoint_url, model: 1000)
@@ -113,6 +125,40 @@ def _compact_prompt_for(monkeypatch, history):
return captured["messages"][1]["content"] return captured["messages"][1]["content"]
def _registered_compact_response(monkeypatch, history, active_run=False):
captured = {}
async def fake_llm_call_async(endpoint_url, model, messages, **kwargs):
captured["messages"] = messages
return "Summary text"
monkeypatch.setattr(
session_routes,
"router",
APIRouter(prefix="/api", tags=["sessions"]),
)
monkeypatch.setattr(session_routes, "_verify_session_owner", lambda request, session_id: None)
monkeypatch.setattr(history_routes, "_verify_session_owner", lambda request, session_id: None)
monkeypatch.setattr(history_routes, "SessionLocal", lambda: _FakeDb())
import src.agent_runs as agent_runs
import src.endpoint_resolver as endpoint_resolver
import src.llm_core as llm_core
monkeypatch.setattr(agent_runs, "is_active", lambda session_id: active_run)
monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", lambda kind, owner=None: (None, None, {}))
monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call_async)
session = _FakeSession(history)
manager = _FakeSessionManager(session)
app = FastAPI()
app.include_router(session_routes.setup_session_routes(manager, {}))
app.include_router(history_routes.setup_history_routes(manager))
response = TestClient(app).post("/api/session/session-1/compact")
return response, captured, manager
def test_manual_compact_tolerates_chatmessage_with_none_content(monkeypatch): def test_manual_compact_tolerates_chatmessage_with_none_content(monkeypatch):
compact_prompt = _compact_prompt_for( compact_prompt = _compact_prompt_for(
monkeypatch, monkeypatch,
@@ -143,3 +189,44 @@ def test_manual_compact_tolerates_dict_message_with_none_content(monkeypatch):
) )
assert "ASSISTANT: None" not in compact_prompt assert "ASSISTANT: None" not in compact_prompt
assert "ASSISTANT: " in compact_prompt assert "ASSISTANT: " in compact_prompt
def test_registered_manual_compact_route_tolerates_none_content(monkeypatch):
response, captured, manager = _registered_compact_response(
monkeypatch,
[
ChatMessage(role="user", content="start"),
ChatMessage(role="assistant", content=None),
ChatMessage(role="tool", content="tool result"),
ChatMessage(role="assistant", content="done"),
ChatMessage(role="user", content="next"),
ChatMessage(role="assistant", content="final"),
],
)
assert response.status_code == 200
assert response.json()["ok"] is True
compact_prompt = captured["messages"][1]["content"]
assert "ASSISTANT: None" not in compact_prompt
assert "ASSISTANT: " in compact_prompt
assert manager.replaced_messages is not None
def test_registered_manual_compact_route_rejects_active_agent_run(monkeypatch):
response, captured, manager = _registered_compact_response(
monkeypatch,
[
ChatMessage(role="user", content="start"),
ChatMessage(role="assistant", content="tool call"),
ChatMessage(role="tool", content="tool result"),
ChatMessage(role="assistant", content="done"),
ChatMessage(role="user", content="next"),
ChatMessage(role="assistant", content="final"),
],
active_run=True,
)
assert response.status_code == 409
assert "active run" in response.text
assert captured == {}
assert manager.replaced_messages is None