fix(mcp): invalidate tool prompt cache on connect/disconnect/error (#1235)
* fix(mcp): invalidate tool prompt cache on connect/disconnect/error get_tool_descriptions_for_prompt cached its result keyed only on (disabled_map, len(_tools)). If a server reconnects with the same tool count (or transitions to error state), the cache was never busted — the agent received stale tool descriptions for the new connection state. Add a _generation counter incremented on every structural change (successful connect, disconnect, connection error) and include it in the cache key. * test(mcp): regression test for _generation cache invalidation
This commit is contained in:
@@ -43,6 +43,8 @@ class McpManager:
|
||||
self._sessions: Dict[str, Any] = {}
|
||||
# server_id -> exit stack (for cleanup)
|
||||
self._stacks: Dict[str, Any] = {}
|
||||
# Tracking updates to tools/connections for RAG indexing
|
||||
self._generation = 0
|
||||
|
||||
async def connect_server(
|
||||
self,
|
||||
@@ -57,16 +59,20 @@ class McpManager:
|
||||
"""Connect to an MCP server via stdio or SSE transport."""
|
||||
try:
|
||||
if transport == "stdio":
|
||||
return await self._connect_stdio(server_id, name, command, args or [], env or {})
|
||||
res = await self._connect_stdio(server_id, name, command, args or [], env or {})
|
||||
elif transport == "sse":
|
||||
return await self._connect_sse(server_id, name, url)
|
||||
res = await self._connect_sse(server_id, name, url)
|
||||
else:
|
||||
logger.error(f"Unknown MCP transport: {transport}")
|
||||
return False
|
||||
res = False
|
||||
if res:
|
||||
self._generation += 1
|
||||
return res
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect MCP server {name} ({server_id}): {e}")
|
||||
error_message = _format_mcp_connection_error(name, command or "", args or [], e)
|
||||
self._connections[server_id] = {"status": "error", "error": error_message, "name": name}
|
||||
self._generation += 1
|
||||
return False
|
||||
|
||||
async def _connect_stdio(self, server_id: str, name: str, command: str, args: List[str], env: Dict[str, str]) -> bool:
|
||||
@@ -182,6 +188,7 @@ class McpManager:
|
||||
self._sessions.pop(server_id, None)
|
||||
self._tools.pop(server_id, None)
|
||||
self._connections.pop(server_id, None)
|
||||
self._generation += 1
|
||||
logger.info(f"MCP server disconnected: {server_id}")
|
||||
|
||||
async def disconnect_all(self):
|
||||
@@ -387,7 +394,11 @@ class McpManager:
|
||||
|
||||
def get_tool_descriptions_for_prompt(self, disabled_map: Optional[Dict[str, set]] = None) -> str:
|
||||
"""Generate text describing MCP tools for the agent system prompt. Cached."""
|
||||
cache_key = (frozenset((k, frozenset(v)) for k, v in (disabled_map or {}).items()), len(self._tools))
|
||||
cache_key = (
|
||||
frozenset((k, frozenset(v)) for k, v in (disabled_map or {}).items()),
|
||||
len(self._tools),
|
||||
self._generation,
|
||||
)
|
||||
if self._cached_prompt_desc is not None and self._cached_prompt_desc_key == cache_key:
|
||||
return self._cached_prompt_desc
|
||||
tools = self.get_all_tools(disabled_map)
|
||||
|
||||
Reference in New Issue
Block a user