fix(mcp): invalidate tool prompt cache on connect/disconnect/error (#1235)

* fix(mcp): invalidate tool prompt cache on connect/disconnect/error

get_tool_descriptions_for_prompt cached its result keyed only on
(disabled_map, len(_tools)). If a server reconnected with the same
tool count (or transitioned to an error state), the cache was never
invalidated, so the agent received stale tool descriptions for the new
connection state.

Add a _generation counter incremented on every structural change
(successful connect, disconnect, connection error) and include it in
the cache key.

* test(mcp): regression test for _generation cache invalidation
This commit is contained in:
Shreyas S Joshi
2026-06-02 21:19:29 +05:30
committed by GitHub
parent 77320b617f
commit b29c200801
2 changed files with 86 additions and 4 deletions

View File

@@ -43,6 +43,8 @@ class McpManager:
self._sessions: Dict[str, Any] = {}
# server_id -> exit stack (for cleanup)
self._stacks: Dict[str, Any] = {}
# Counter bumped on every structural change (connect, disconnect, connection error); part of the tool prompt cache key
self._generation = 0
async def connect_server(
self,
@@ -57,16 +59,20 @@ class McpManager:
"""Connect to an MCP server via stdio or SSE transport."""
try:
if transport == "stdio":
return await self._connect_stdio(server_id, name, command, args or [], env or {})
res = await self._connect_stdio(server_id, name, command, args or [], env or {})
elif transport == "sse":
return await self._connect_sse(server_id, name, url)
res = await self._connect_sse(server_id, name, url)
else:
logger.error(f"Unknown MCP transport: {transport}")
return False
res = False
if res:
self._generation += 1
return res
except Exception as e:
logger.error(f"Failed to connect MCP server {name} ({server_id}): {e}")
error_message = _format_mcp_connection_error(name, command or "", args or [], e)
self._connections[server_id] = {"status": "error", "error": error_message, "name": name}
self._generation += 1
return False
async def _connect_stdio(self, server_id: str, name: str, command: str, args: List[str], env: Dict[str, str]) -> bool:
@@ -182,6 +188,7 @@ class McpManager:
self._sessions.pop(server_id, None)
self._tools.pop(server_id, None)
self._connections.pop(server_id, None)
self._generation += 1
logger.info(f"MCP server disconnected: {server_id}")
async def disconnect_all(self):
@@ -387,7 +394,11 @@ class McpManager:
def get_tool_descriptions_for_prompt(self, disabled_map: Optional[Dict[str, set]] = None) -> str:
"""Generate text describing MCP tools for the agent system prompt. Cached."""
cache_key = (frozenset((k, frozenset(v)) for k, v in (disabled_map or {}).items()), len(self._tools))
cache_key = (
frozenset((k, frozenset(v)) for k, v in (disabled_map or {}).items()),
len(self._tools),
self._generation,
)
if self._cached_prompt_desc is not None and self._cached_prompt_desc_key == cache_key:
return self._cached_prompt_desc
tools = self.get_all_tools(disabled_map)