diff --git a/src/mcp_manager.py b/src/mcp_manager.py index 3cddfab..e588a10 100644 --- a/src/mcp_manager.py +++ b/src/mcp_manager.py @@ -43,6 +43,8 @@ class McpManager: self._sessions: Dict[str, Any] = {} # server_id -> exit stack (for cleanup) self._stacks: Dict[str, Any] = {} + # Tracking updates to tools/connections for RAG indexing + self._generation = 0 async def connect_server( self, @@ -57,16 +59,20 @@ class McpManager: """Connect to an MCP server via stdio or SSE transport.""" try: if transport == "stdio": - return await self._connect_stdio(server_id, name, command, args or [], env or {}) + res = await self._connect_stdio(server_id, name, command, args or [], env or {}) elif transport == "sse": - return await self._connect_sse(server_id, name, url) + res = await self._connect_sse(server_id, name, url) else: logger.error(f"Unknown MCP transport: {transport}") - return False + res = False + if res: + self._generation += 1 + return res except Exception as e: logger.error(f"Failed to connect MCP server {name} ({server_id}): {e}") error_message = _format_mcp_connection_error(name, command or "", args or [], e) self._connections[server_id] = {"status": "error", "error": error_message, "name": name} + self._generation += 1 return False async def _connect_stdio(self, server_id: str, name: str, command: str, args: List[str], env: Dict[str, str]) -> bool: @@ -182,6 +188,7 @@ class McpManager: self._sessions.pop(server_id, None) self._tools.pop(server_id, None) self._connections.pop(server_id, None) + self._generation += 1 logger.info(f"MCP server disconnected: {server_id}") async def disconnect_all(self): @@ -387,7 +394,11 @@ class McpManager: def get_tool_descriptions_for_prompt(self, disabled_map: Optional[Dict[str, set]] = None) -> str: """Generate text describing MCP tools for the agent system prompt. Cached.""" - cache_key = (frozenset((k, frozenset(v)) for k, v in (disabled_map or {}).items()), len(self._tools)) + cache_key = ( + frozenset((k, frozenset(v)) for k, v in (disabled_map or {}).items()), + len(self._tools), + self._generation, + ) if self._cached_prompt_desc is not None and self._cached_prompt_desc_key == cache_key: return self._cached_prompt_desc tools = self.get_all_tools(disabled_map) diff --git a/tests/test_mcp_cache_invalidation.py b/tests/test_mcp_cache_invalidation.py new file mode 100644 index 0000000..3324e92 --- /dev/null +++ b/tests/test_mcp_cache_invalidation.py @@ -0,0 +1,71 @@ +"""Regression test: McpManager._generation must bust the tool prompt cache +when a server connects/disconnects with the same tool count. + +Before the fix, cache_key was (disabled_map, len(_tools)). A reconnect that +preserved the tool count left the stale description in place. After the fix +the _generation counter is included so any structural change invalidates it. +""" +import asyncio + +from src.mcp_manager import McpManager + + +def _make_mgr(): + return McpManager() + + +def _inject_tools(mgr, server_id: str, tools: list): + """Directly populate internal dicts as _connect_stdio would after success.""" + mgr._tools[server_id] = tools + mgr._connections[server_id] = {"status": "connected", "name": server_id} + + +# --------------------------------------------------------------------------- +# _generation increments on disconnect +# --------------------------------------------------------------------------- + +def test_generation_increments_on_disconnect(): + mgr = _make_mgr() + assert mgr._generation == 0 + _inject_tools(mgr, "srv1", [{"name": "tool_a"}]) + mgr._generation += 1 # simulate connect increment + + gen_before = mgr._generation + asyncio.run(mgr.disconnect_server("srv1")) + assert mgr._generation == gen_before + 1 + + +# --------------------------------------------------------------------------- +# Core cache-invalidation regression: stale description after reconnect +# --------------------------------------------------------------------------- + +def test_prompt_cache_busted_after_disconnect_same_tool_count(): + """The stale-cache bug: two different servers each have 1 tool. + After the first disconnects and the second connects, the cache must + reflect the new server's tools, not the old one's description. + """ + mgr = _make_mgr() + + # Connect server A with one tool + _inject_tools(mgr, "srv_a", [{"name": "tool_alpha", "description": "Alpha tool", + "inputSchema": {"type": "object", "properties": {}}}]) + mgr._generation += 1 # simulated successful connect + + desc_a = mgr.get_tool_descriptions_for_prompt() + assert "tool_alpha" in desc_a + + # Disconnect A — same tool count (1) as what follows + asyncio.run(mgr.disconnect_server("srv_a")) # bumps _generation + + # Connect server B with a *different* tool but same count (1) + _inject_tools(mgr, "srv_b", [{"name": "tool_beta", "description": "Beta tool", + "inputSchema": {"type": "object", "properties": {}}}]) + mgr._generation += 1 # simulated successful connect + + desc_b = mgr.get_tool_descriptions_for_prompt() + + # Without the fix both describe tool_alpha (stale cache hit). + assert "tool_beta" in desc_b, ( + "Cache was not invalidated: got stale description after reconnect" + ) + assert "tool_alpha" not in desc_b