Fix compaction crash on None / multimodal message content.

maybe_compact() sliced msg['content'] directly, which raised TypeError when an
older assistant turn carried only native tool_calls (content is None) or when
content was a multimodal list. _content_as_text() flattens all three shapes to
a plain string before slicing; blocks whose "text" is not a string are skipped
and a None role falls back to 'user'.

NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. Added
lines are exact, but the indentation/spacing of unchanged context lines
(notably inside maybe_compact and around inline comments) is a best guess;
verify against the target tree or apply with `git apply --ignore-whitespace`.
The post-image blob ids on the index lines predate the review changes.

diff --git a/src/context_compactor.py b/src/context_compactor.py
index 2d0b15f..8ed5909 100644
--- a/src/context_compactor.py
+++ b/src/context_compactor.py
@@ -15,6 +15,26 @@ from core.models import ChatMessage
 
 logger = logging.getLogger(__name__)
 
+
+def _content_as_text(content: Any) -> str:
+    """Flatten a message's content to plain text.
+
+    Handles the three shapes that flow through history: a plain string, a
+    multimodal list of content blocks (vision/image attachments), and None
+    (assistant turns that carried only native tool_calls persist content as
+    None). Blocks whose "text" is missing or not a string are skipped, and
+    anything else yields "", so callers can safely slice the result.
+    """
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        return " ".join(
+            b["text"] for b in content
+            if isinstance(b, dict) and isinstance(b.get("text"), str) and b["text"]
+        )
+    return ""
+
+
 COMPACT_THRESHOLD = 0.85  # Trigger compaction at 85% of context window
 SUMMARY_MAX_TOKENS = 1024
 SMALL_CONTEXT_LIMIT = 8192  # Models with context <= this get aggressive trimming
@@ -274,7 +294,7 @@ async def maybe_compact(
 
     # Build the text to summarize
     convo_text = "\n".join(
-        f"{msg['role'].upper()}: {msg.get('content', '')[:2000]}"
+        f"{(msg.get('role') or 'user').upper()}: {_content_as_text(msg.get('content'))[:2000]}"
         for msg in older
     )
 
diff --git a/tests/test_context_compactor.py b/tests/test_context_compactor.py
index 5a1dfa3..393b4ac 100644
--- a/tests/test_context_compactor.py
+++ b/tests/test_context_compactor.py
@@ -1,9 +1,12 @@
 """Tests for context_compactor.py — constants and prompt templates.
 Uses mock imports to avoid loading the full app stack."""
 
+import asyncio
 import sys
 from unittest.mock import MagicMock
 
+import pytest
+
 # Mock heavy dependencies before importing
 for mod in [
     'sqlalchemy', 'sqlalchemy.orm', 'sqlalchemy.ext', 'sqlalchemy.ext.declarative',
@@ -14,10 +17,13 @@ for mod in [
     if mod not in sys.modules:
         sys.modules[mod] = MagicMock()
 
+import src.context_compactor as cc
 from src.context_compactor import (
     COMPACT_THRESHOLD,
     SELF_SUMMARY_SYSTEM_PROMPT,
     SUMMARY_MAX_TOKENS,
+    _content_as_text,
+    maybe_compact,
     trim_for_context,
 )
 
@@ -84,3 +90,108 @@ class TestTrimForContext:
         assert trimmed[-1]["role"] == "user"
         assert "pasted message was too large" in trimmed[-1]["content"]
         assert "old-0" not in "\n".join(str(m.get("content", "")) for m in trimmed)
+
+
+class TestContentAsText:
+    def test_string_passthrough(self):
+        assert _content_as_text("hello") == "hello"
+
+    def test_none_returns_empty(self):
+        # Assistant turns that carried only native tool_calls persist
+        # content as None — flattening must not raise.
+        assert _content_as_text(None) == ""
+
+    def test_list_content_joins_text_blocks(self):
+        content = [
+            {"type": "text", "text": "describe this"},
+            {"type": "image_url", "image_url": {"url": "data:..."}},
+        ]
+        assert _content_as_text(content) == "describe this"
+
+    def test_unknown_type_returns_empty(self):
+        assert _content_as_text(42) == ""
+
+    def test_non_string_text_block_is_skipped(self):
+        assert _content_as_text([{"text": 123}, {"text": "ok"}]) == "ok"
+
+
+class TestMaybeCompactFourthMessage:
+    """Regression: a multi-message conversation must not crash compaction when
+    a prior assistant turn used native tool_calls (content == None). This was
+    the '4th message stops working' bug — on a small-context model the soft
+    85% threshold is crossed after a few turns, and the older half being
+    summarized contained a None-content assistant message, which raised
+    TypeError: 'NoneType' object is not subscriptable and broke the request."""
+
+    def _run(self, messages, *, context_length=500):
+        # Force compaction to trigger and stub the summary LLM call so the test
+        # is hermetic (no network, no real endpoint resolution).
+        orig_ctx = cc.get_context_length
+        orig_call = cc.llm_call_async
+        orig_resolve = cc.resolve_endpoint
+        orig_update = cc._update_session_history
+
+        async def _fake_summary(*a, **k):
+            return "compact summary text"
+
+        cc.get_context_length = lambda url, model: context_length
+        cc.llm_call_async = _fake_summary
+        cc.resolve_endpoint = lambda which: (None, None, None)
+        cc._update_session_history = lambda *a, **k: None
+        try:
+            return asyncio.run(
+                maybe_compact(
+                    session=None,
+                    endpoint_url="http://local/v1/chat/completions",
+                    model="local-model",
+                    messages=list(messages),
+                    headers={},
+                )
+            )
+        finally:
+            cc.get_context_length = orig_ctx
+            cc.llm_call_async = orig_call
+            cc.resolve_endpoint = orig_resolve
+            cc._update_session_history = orig_update
+
+    def _four_turn_history_with_tool_call(self):
+        # Large system prompt so the conversation crosses the 85% threshold of
+        # the tiny (context_length=500) window used in _run, forcing the real
+        # compaction branch to execute.
+        return [
+            {"role": "system", "content": "You are a helpful agent. " * 200},
+            {"role": "user", "content": "turn 1: search the web"},
+            # Native tool call → content is None (matches agent_loop persistence)
+            {"role": "assistant", "content": None,
+             "tool_calls": [{"id": "c1", "type": "function",
+                             "function": {"name": "web_search", "arguments": "{}"}}]},
+            {"role": "tool", "tool_call_id": "c1", "content": "search results"},
+            {"role": "assistant", "content": "Here is what I found."},
+            {"role": "user", "content": "turn 2"},
+            {"role": "assistant", "content": "reply 2"},
+            {"role": "user", "content": "turn 3"},
+            {"role": "assistant", "content": "reply 3"},
+            {"role": "user", "content": "turn 4 — previously broke here"},
+        ]
+
+    def test_does_not_crash_on_none_content_turn(self):
+        # Must not raise TypeError; returns the 3-tuple contract.
+        result = self._run(self._four_turn_history_with_tool_call())
+        assert isinstance(result, tuple) and len(result) == 3
+        compacted_messages, context_length, was_compacted = result
+        assert isinstance(compacted_messages, list)
+        assert was_compacted is True
+        # The summary the model produced is present and a system message.
+        assert any(
+            m.get("role") == "system" and "compact summary text" in (m.get("content") or "")
+            for m in compacted_messages
+        )
+
+    def test_handles_multimodal_list_content(self):
+        messages = self._four_turn_history_with_tool_call()
+        messages[1] = {"role": "user", "content": [
+            {"type": "text", "text": "look at this image"},
+            {"type": "image_url", "image_url": {"url": "data:image/png;base64,xxxx"}},
+        ]}
+        result = self._run(messages)
+        assert len(result) == 3 and result[2] is True