diff --git a/src/tool_implementations.py b/src/tool_implementations.py index 3847cf8..722c39f 100644 --- a/src/tool_implementations.py +++ b/src/tool_implementations.py @@ -1215,7 +1215,17 @@ async def do_manage_mcp(content: str, owner: Optional[str] = None) -> Dict: try: srv = db2.query(McpServer).filter(McpServer.id == sid).first() if srv: - await mcp.connect_server(sid) + _args = json.loads(srv.args) if srv.args else [] + _env = json.loads(srv.env) if srv.env else {} + await mcp.connect_server( + server_id=sid, + name=srv.name, + transport=srv.transport, + command=srv.command, + args=_args, + env=_env, + url=srv.url, + ) st = mcp.get_server_status(sid) return {"response": f"Reconnected '{srv.name}' ({st.get('tool_count', 0)} tools)", "exit_code": 0} return {"error": f"Server {sid} not found", "exit_code": 1} diff --git a/tests/test_mcp_reconnect_args.py b/tests/test_mcp_reconnect_args.py new file mode 100644 index 0000000..b2a1e8b --- /dev/null +++ b/tests/test_mcp_reconnect_args.py @@ -0,0 +1,46 @@ +"""Verify that MCP reconnect via the agent tool passes full server metadata.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch +from types import SimpleNamespace + + +def test_reconnect_passes_full_server_config(): + """do_manage_mcp reconnect must pass name/transport/command/args/env/url.""" + from src.tool_implementations import do_manage_mcp + + fake_mcp = MagicMock() + fake_mcp.disconnect_server = AsyncMock() + fake_mcp.connect_server = AsyncMock(return_value=True) + fake_mcp.get_server_status = MagicMock(return_value={"tool_count": 3}) + + fake_srv = SimpleNamespace( + id="srv-123", + name="test-server", + transport="stdio", + command="/usr/bin/test", + args=json.dumps(["--flag"]), + env=json.dumps({"KEY": "val"}), + url=None, + ) + + fake_db = MagicMock() + fake_db.query.return_value.filter.return_value.first.return_value = fake_srv + + with patch("src.tool_implementations.get_mcp_manager", return_value=fake_mcp), \ + patch("core.database.SessionLocal", return_value=fake_db): + result = asyncio.run(do_manage_mcp( + json.dumps({"action": "reconnect", "server_id": "srv-123"}) + )) + + assert result["exit_code"] == 0 + fake_mcp.connect_server.assert_called_once_with( + server_id="srv-123", + name="test-server", + transport="stdio", + command="/usr/bin/test", + args=["--flag"], + env={"KEY": "val"}, + url=None, + )