Odysseus v1.0
This commit is contained in:
299
src/context_compactor.py
Normal file
299
src/context_compactor.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
context_compactor.py
|
||||
|
||||
Auto-compacts conversation history when approaching context window limits.
|
||||
Summarizes older messages via the same LLM, preserving key context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.model_context import get_context_length, estimate_tokens
|
||||
from src.llm_core import llm_call_async
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from core.models import ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMPACT_THRESHOLD = 0.85 # Trigger compaction at 85% of context window
|
||||
SUMMARY_MAX_TOKENS = 1024
|
||||
SMALL_CONTEXT_LIMIT = 8192 # Models with context <= this get aggressive trimming
|
||||
|
||||
# Cursor-style self-summarization prompt — produces structured, dense summaries
|
||||
SELF_SUMMARY_SYSTEM_PROMPT = """You are summarizing a conversation to preserve context after compaction. Produce a structured summary that lets the conversation continue seamlessly.
|
||||
|
||||
Use this format:
|
||||
|
||||
## Conversation Summary
|
||||
**Turns summarized:** {count} | **Compactions so far:** {n}
|
||||
|
||||
### User Goal
|
||||
One sentence describing what the user is trying to accomplish.
|
||||
|
||||
### What Was Done
|
||||
- Bullet points of completed actions, decisions made, and key outputs
|
||||
- Include specific file paths, function names, variable names, URLs, and config values
|
||||
- Note any errors encountered and how they were resolved
|
||||
|
||||
### Current State
|
||||
What is the system/code/task state right now? What was the last thing discussed?
|
||||
|
||||
### Pending / Next Steps
|
||||
- What remains to be done
|
||||
- Any open questions or blockers
|
||||
|
||||
### Key Context
|
||||
- Important constraints, preferences, or decisions that must not be forgotten
|
||||
- Specific values: model names, ports, paths, credentials references, versions
|
||||
|
||||
Keep the summary under 1000 tokens. Be dense — every token should carry information. Do not include pleasantries or meta-commentary."""
|
||||
|
||||
|
||||
def _sanitize_tool_messages(msgs: List[Dict]) -> List[Dict]:
|
||||
"""Drop orphaned `tool` messages and dangling assistant `tool_calls`.
|
||||
|
||||
OpenAI's API requires every `role:"tool"` message to immediately
|
||||
follow an assistant message that carries `tool_calls` (or another
|
||||
tool message in the same batch). Front-trimming the history can cut
|
||||
the assistant `tool_calls` parent while keeping its tool responses,
|
||||
which triggers: "messages with role 'tool' must be a response to a
|
||||
preceding message with 'tool_calls'". This pass repairs that:
|
||||
- drops `tool` messages with no valid preceding tool_calls
|
||||
- drops assistant `tool_calls` messages whose tool responses were
|
||||
all trimmed away (some providers reject unanswered tool_calls)
|
||||
"""
|
||||
# Pass 1: drop orphan tool messages.
|
||||
cleaned: List[Dict] = []
|
||||
in_batch = False # are we right after an assistant tool_calls (or mid-batch)?
|
||||
for m in msgs:
|
||||
role = m.get("role")
|
||||
if role == "tool":
|
||||
if in_batch:
|
||||
cleaned.append(m)
|
||||
# else: orphan — drop
|
||||
continue
|
||||
if role == "assistant" and m.get("tool_calls"):
|
||||
in_batch = True
|
||||
else:
|
||||
in_batch = False
|
||||
cleaned.append(m)
|
||||
|
||||
# Pass 2: drop assistant tool_calls messages that have NO following
|
||||
# tool response (dangling) — walk backwards so we know what follows.
|
||||
out: List[Dict] = []
|
||||
for i, m in enumerate(cleaned):
|
||||
if m.get("role") == "assistant" and m.get("tool_calls"):
|
||||
nxt = cleaned[i + 1] if i + 1 < len(cleaned) else None
|
||||
if not (nxt and nxt.get("role") == "tool"):
|
||||
# Dangling tool_calls — keep the message but strip the
|
||||
# tool_calls so it's a plain assistant turn (preserves any
|
||||
# text content the model produced alongside the calls).
|
||||
m = {k: v for k, v in m.items() if k != "tool_calls"}
|
||||
if not (m.get("content") or "").strip():
|
||||
continue # nothing left worth keeping
|
||||
out.append(m)
|
||||
return out
|
||||
|
||||
|
||||
def trim_for_context(messages: List[Dict], context_length: int, reserve_tokens: int = 512) -> List[Dict]:
|
||||
"""Trim system messages to fit within context_length.
|
||||
|
||||
For small-context models, progressively strips:
|
||||
1. RAG/memory system messages (keep preset system prompt)
|
||||
2. Older conversation turns
|
||||
Reserves space for the response.
|
||||
"""
|
||||
budget = context_length - reserve_tokens
|
||||
used = estimate_tokens(messages)
|
||||
if used <= budget:
|
||||
return messages
|
||||
|
||||
logger.info(f"Trimming messages: {used} tokens > {budget} budget (ctx={context_length})")
|
||||
|
||||
# Separate system messages from conversation.
|
||||
# Messages marked _protected (e.g. active document) are never trimmed.
|
||||
system_msgs = []
|
||||
protected_msgs = []
|
||||
convo_msgs = []
|
||||
for msg in messages:
|
||||
if msg.get("_protected"):
|
||||
protected_msgs.append(msg)
|
||||
elif msg.get("role") == "system":
|
||||
system_msgs.append(msg)
|
||||
else:
|
||||
convo_msgs.append(msg)
|
||||
|
||||
# Protected messages count toward budget but are never dropped
|
||||
protected_tokens = estimate_tokens(protected_msgs)
|
||||
budget -= protected_tokens
|
||||
|
||||
# Priority: keep first system msg (preset prompt), drop others (memory, RAG, memo)
|
||||
essential_system = system_msgs[:1] if system_msgs else []
|
||||
extra_system = system_msgs[1:]
|
||||
|
||||
# Try dropping extra system messages one by one (from the end)
|
||||
trimmed = essential_system + convo_msgs
|
||||
if estimate_tokens(trimmed) <= budget:
|
||||
# Dropping extras was enough — try adding back some
|
||||
result = list(essential_system)
|
||||
for msg in extra_system:
|
||||
candidate = result + [msg] + convo_msgs
|
||||
if estimate_tokens(candidate) <= budget:
|
||||
result.append(msg)
|
||||
else:
|
||||
break
|
||||
return _sanitize_tool_messages(result + protected_msgs + convo_msgs)
|
||||
|
||||
# Still too big — truncate the first system message (but keep more than 500 chars)
|
||||
if essential_system:
|
||||
sys_text = essential_system[0].get("content", "")
|
||||
if len(sys_text) > 2000:
|
||||
essential_system[0] = {"role": "system", "content": sys_text[:2000] + "\n[System prompt truncated for context limits]"}
|
||||
trimmed = essential_system + convo_msgs
|
||||
if estimate_tokens(trimmed) <= budget:
|
||||
return _sanitize_tool_messages(essential_system + protected_msgs + convo_msgs)
|
||||
|
||||
# Still too big — drop older conversation turns BUT protect the last 10.
|
||||
# Hermes-style: recent context matters more than old context.
|
||||
PROTECT_RECENT = 10
|
||||
if len(convo_msgs) > PROTECT_RECENT:
|
||||
old_msgs = convo_msgs[:-PROTECT_RECENT]
|
||||
recent_msgs = convo_msgs[-PROTECT_RECENT:]
|
||||
while old_msgs and estimate_tokens(essential_system + old_msgs + recent_msgs) > budget:
|
||||
old_msgs.pop(0)
|
||||
convo_msgs = old_msgs + recent_msgs
|
||||
else:
|
||||
# Not enough messages to split — just trim from front
|
||||
while convo_msgs and estimate_tokens(essential_system + convo_msgs) > budget:
|
||||
convo_msgs.pop(0)
|
||||
|
||||
result = _sanitize_tool_messages(essential_system + protected_msgs + convo_msgs)
|
||||
logger.info(f"Trimmed to {estimate_tokens(result)} tokens ({len(result)} messages)")
|
||||
return result
|
||||
|
||||
|
||||
async def maybe_compact(
|
||||
session,
|
||||
endpoint_url: str,
|
||||
model: str,
|
||||
messages: List[Dict],
|
||||
headers: Optional[Dict] = None,
|
||||
) -> tuple:
|
||||
"""Check context usage and compact if above threshold.
|
||||
|
||||
Returns (messages, context_length, was_compacted).
|
||||
"""
|
||||
context_length = get_context_length(endpoint_url, model)
|
||||
used = estimate_tokens(messages)
|
||||
pct = (used / context_length) * 100 if context_length else 0
|
||||
|
||||
if pct < COMPACT_THRESHOLD * 100:
|
||||
return messages, context_length, False
|
||||
|
||||
logger.info(
|
||||
f"Context at {pct:.1f}% ({used}/{context_length} tokens) — compacting"
|
||||
)
|
||||
|
||||
# Split into system preface and conversation
|
||||
system_msgs = []
|
||||
convo_msgs = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
system_msgs.append(msg)
|
||||
else:
|
||||
convo_msgs.append(msg)
|
||||
|
||||
if len(convo_msgs) < 4:
|
||||
return messages, context_length, False
|
||||
|
||||
# Split conversation: summarize older half, keep recent half
|
||||
split_point = len(convo_msgs) // 2
|
||||
older = convo_msgs[:split_point]
|
||||
recent = convo_msgs[split_point:]
|
||||
|
||||
# Build the text to summarize
|
||||
convo_text = "\n".join(
|
||||
f"{msg['role'].upper()}: {msg.get('content', '')[:2000]}"
|
||||
for msg in older
|
||||
)
|
||||
|
||||
# Count prior compactions from existing summary messages
|
||||
compaction_count = sum(
|
||||
1 for m in system_msgs
|
||||
if "[Conversation summary" in m.get("content", "")
|
||||
)
|
||||
|
||||
# Use utility model if configured, otherwise fall back to session model
|
||||
util_url, util_model, util_headers = resolve_endpoint("utility")
|
||||
compact_url = util_url or endpoint_url
|
||||
compact_model = util_model or model
|
||||
compact_headers = util_headers if util_url else headers
|
||||
|
||||
prompt = SELF_SUMMARY_SYSTEM_PROMPT.replace(
|
||||
"{count}", str(len(older))
|
||||
).replace(
|
||||
"{n}", str(compaction_count + 1)
|
||||
)
|
||||
summary_messages = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": convo_text},
|
||||
]
|
||||
|
||||
try:
|
||||
summary = await llm_call_async(
|
||||
compact_url,
|
||||
compact_model,
|
||||
summary_messages,
|
||||
temperature=0.2,
|
||||
max_tokens=SUMMARY_MAX_TOKENS,
|
||||
headers=compact_headers,
|
||||
timeout=30,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Compaction summary failed: {e}")
|
||||
return system_msgs + recent, context_length, False
|
||||
|
||||
summary_msg = {
|
||||
"role": "system",
|
||||
"content": f"[Conversation summary — earlier messages were compacted]\n{summary}",
|
||||
}
|
||||
|
||||
compacted = system_msgs + [summary_msg] + recent
|
||||
|
||||
# Update session history to match
|
||||
_update_session_history(session, split_point, summary)
|
||||
|
||||
new_used = estimate_tokens(compacted)
|
||||
logger.info(
|
||||
f"Compacted: {used} -> {new_used} tokens "
|
||||
f"({len(older)} messages summarized, {len(recent)} kept)"
|
||||
)
|
||||
|
||||
return compacted, context_length, True
|
||||
|
||||
|
||||
def _update_session_history(session, split_point: int, summary: str):
|
||||
"""Update the in-memory session history after compaction."""
|
||||
if not session or not hasattr(session, "history"):
|
||||
return
|
||||
|
||||
if split_point >= len(session.history):
|
||||
return
|
||||
|
||||
# Keep the recent messages, prepend summary
|
||||
recent_history = session.history[split_point:]
|
||||
summary_msg = ChatMessage(
|
||||
role="system",
|
||||
content=f"[Conversation summary]\n{summary}",
|
||||
metadata={"compacted": True, "summarized_count": split_point},
|
||||
)
|
||||
new_history = [summary_msg] + recent_history
|
||||
try:
|
||||
from core import models as _core_models
|
||||
manager = getattr(_core_models, "_session_manager", None)
|
||||
except Exception:
|
||||
manager = None
|
||||
if manager and getattr(session, "id", None):
|
||||
if manager.replace_messages(session.id, new_history):
|
||||
return
|
||||
session.history = new_history
|
||||
Reference in New Issue
Block a user