fix(mcp): invalidate tool prompt cache on connect/disconnect/error (#1235)
* fix(mcp): invalidate tool prompt cache on connect/disconnect/error

  get_tool_descriptions_for_prompt cached its result keyed only on (disabled_map, len(_tools)). If a server reconnected with the same tool count (or transitioned to an error state), the cache was never busted, so the agent received stale tool descriptions for the new connection state.

  Add a _generation counter that is incremented on every structural change (successful connect, disconnect, connection error) and include it in the cache key.

* test(mcp): regression test for _generation cache invalidation
This commit is contained in:
@@ -43,6 +43,8 @@ class McpManager:
|
|||||||
self._sessions: Dict[str, Any] = {}
|
self._sessions: Dict[str, Any] = {}
|
||||||
# server_id -> exit stack (for cleanup)
|
# server_id -> exit stack (for cleanup)
|
||||||
self._stacks: Dict[str, Any] = {}
|
self._stacks: Dict[str, Any] = {}
|
||||||
|
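# Generation counter: bumped on every successful connect, disconnect, and connection error; part of the tool-prompt cache key (also tracks tool/connection updates for RAG indexing)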
|
||||||
|
self._generation = 0
|
||||||
|
|
||||||
async def connect_server(
|
async def connect_server(
|
||||||
self,
|
self,
|
||||||
@@ -57,16 +59,20 @@ class McpManager:
|
|||||||
"""Connect to an MCP server via stdio or SSE transport."""
|
"""Connect to an MCP server via stdio or SSE transport."""
|
||||||
try:
|
try:
|
||||||
if transport == "stdio":
|
if transport == "stdio":
|
||||||
return await self._connect_stdio(server_id, name, command, args or [], env or {})
|
res = await self._connect_stdio(server_id, name, command, args or [], env or {})
|
||||||
elif transport == "sse":
|
elif transport == "sse":
|
||||||
return await self._connect_sse(server_id, name, url)
|
res = await self._connect_sse(server_id, name, url)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown MCP transport: {transport}")
|
logger.error(f"Unknown MCP transport: {transport}")
|
||||||
return False
|
res = False
|
||||||
|
if res:
|
||||||
|
self._generation += 1
|
||||||
|
return res
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect MCP server {name} ({server_id}): {e}")
|
logger.error(f"Failed to connect MCP server {name} ({server_id}): {e}")
|
||||||
error_message = _format_mcp_connection_error(name, command or "", args or [], e)
|
error_message = _format_mcp_connection_error(name, command or "", args or [], e)
|
||||||
self._connections[server_id] = {"status": "error", "error": error_message, "name": name}
|
self._connections[server_id] = {"status": "error", "error": error_message, "name": name}
|
||||||
|
self._generation += 1
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _connect_stdio(self, server_id: str, name: str, command: str, args: List[str], env: Dict[str, str]) -> bool:
|
async def _connect_stdio(self, server_id: str, name: str, command: str, args: List[str], env: Dict[str, str]) -> bool:
|
||||||
@@ -182,6 +188,7 @@ class McpManager:
|
|||||||
self._sessions.pop(server_id, None)
|
self._sessions.pop(server_id, None)
|
||||||
self._tools.pop(server_id, None)
|
self._tools.pop(server_id, None)
|
||||||
self._connections.pop(server_id, None)
|
self._connections.pop(server_id, None)
|
||||||
|
self._generation += 1
|
||||||
logger.info(f"MCP server disconnected: {server_id}")
|
logger.info(f"MCP server disconnected: {server_id}")
|
||||||
|
|
||||||
async def disconnect_all(self):
|
async def disconnect_all(self):
|
||||||
@@ -387,7 +394,11 @@ class McpManager:
|
|||||||
|
|
||||||
def get_tool_descriptions_for_prompt(self, disabled_map: Optional[Dict[str, set]] = None) -> str:
|
def get_tool_descriptions_for_prompt(self, disabled_map: Optional[Dict[str, set]] = None) -> str:
|
||||||
"""Generate text describing MCP tools for the agent system prompt. Cached."""
|
"""Generate text describing MCP tools for the agent system prompt. Cached."""
|
||||||
cache_key = (frozenset((k, frozenset(v)) for k, v in (disabled_map or {}).items()), len(self._tools))
|
cache_key = (
|
||||||
|
frozenset((k, frozenset(v)) for k, v in (disabled_map or {}).items()),
|
||||||
|
len(self._tools),
|
||||||
|
self._generation,
|
||||||
|
)
|
||||||
if self._cached_prompt_desc is not None and self._cached_prompt_desc_key == cache_key:
|
if self._cached_prompt_desc is not None and self._cached_prompt_desc_key == cache_key:
|
||||||
return self._cached_prompt_desc
|
return self._cached_prompt_desc
|
||||||
tools = self.get_all_tools(disabled_map)
|
tools = self.get_all_tools(disabled_map)
|
||||||
|
|||||||
71
tests/test_mcp_cache_invalidation.py
Normal file
71
tests/test_mcp_cache_invalidation.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""Regression test: McpManager._generation must bust the tool prompt cache
|
||||||
|
when a server connects/disconnects with the same tool count.
|
||||||
|
|
||||||
|
Before the fix, cache_key was (disabled_map, len(_tools)). A reconnect that
|
||||||
|
preserved the tool count left the stale description in place. After the fix
|
||||||
|
the _generation counter is included so any structural change invalidates it.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from src.mcp_manager import McpManager
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mgr():
|
||||||
|
return McpManager()
|
||||||
|
|
||||||
|
|
||||||
|
def _inject_tools(mgr, server_id: str, tools: list):
|
||||||
|
"""Directly populate internal dicts as _connect_stdio would after success."""
|
||||||
|
mgr._tools[server_id] = tools
|
||||||
|
mgr._connections[server_id] = {"status": "connected", "name": server_id}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _generation increments on disconnect
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_generation_increments_on_disconnect():
|
||||||
|
mgr = _make_mgr()
|
||||||
|
assert mgr._generation == 0
|
||||||
|
_inject_tools(mgr, "srv1", [{"name": "tool_a"}])
|
||||||
|
mgr._generation += 1 # simulate connect increment
|
||||||
|
|
||||||
|
gen_before = mgr._generation
|
||||||
|
asyncio.run(mgr.disconnect_server("srv1"))
|
||||||
|
assert mgr._generation == gen_before + 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Core cache-invalidation regression: stale description after reconnect
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_prompt_cache_busted_after_disconnect_same_tool_count():
|
||||||
|
"""The stale-cache bug: two different servers each have 1 tool.
|
||||||
|
After the first disconnects and the second connects, the cache must
|
||||||
|
reflect the new server's tools, not the old one's description.
|
||||||
|
"""
|
||||||
|
mgr = _make_mgr()
|
||||||
|
|
||||||
|
# Connect server A with one tool
|
||||||
|
_inject_tools(mgr, "srv_a", [{"name": "tool_alpha", "description": "Alpha tool",
|
||||||
|
"inputSchema": {"type": "object", "properties": {}}}])
|
||||||
|
mgr._generation += 1 # simulated successful connect
|
||||||
|
|
||||||
|
desc_a = mgr.get_tool_descriptions_for_prompt()
|
||||||
|
assert "tool_alpha" in desc_a
|
||||||
|
|
||||||
|
# Disconnect A — same tool count (1) as what follows
|
||||||
|
asyncio.run(mgr.disconnect_server("srv_a")) # bumps _generation
|
||||||
|
|
||||||
|
# Connect server B with a *different* tool but same count (1)
|
||||||
|
_inject_tools(mgr, "srv_b", [{"name": "tool_beta", "description": "Beta tool",
|
||||||
|
"inputSchema": {"type": "object", "properties": {}}}])
|
||||||
|
mgr._generation += 1 # simulated successful connect
|
||||||
|
|
||||||
|
desc_b = mgr.get_tool_descriptions_for_prompt()
|
||||||
|
|
||||||
|
# Without the fix both describe tool_alpha (stale cache hit).
|
||||||
|
assert "tool_beta" in desc_b, (
|
||||||
|
"Cache was not invalidated: got stale description after reconnect"
|
||||||
|
)
|
||||||
|
assert "tool_alpha" not in desc_b
|
||||||
Reference in New Issue
Block a user