diff --git a/mcp_servers/rag_server.py b/mcp_servers/rag_server.py index d70aa1c..1dfd464 100644 --- a/mcp_servers/rag_server.py +++ b/mcp_servers/rag_server.py @@ -101,7 +101,8 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Error: {e}")] elif action == "add_directory": - directory = arguments.get("directory", "").strip() + _dir = arguments.get("directory") + directory = _dir.strip() if isinstance(_dir, str) else "" if not directory: return [TextContent(type="text", text="Error: add_directory needs a directory path")] directory = os.path.expanduser(directory) @@ -126,7 +127,8 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Error: Failed to index directory: {e}")] elif action == "remove_directory": - directory = arguments.get("directory", "").strip() + _dir = arguments.get("directory") + directory = _dir.strip() if isinstance(_dir, str) else "" if not directory: return [TextContent(type="text", text="Error: remove_directory needs a directory path")] # Expand ~ to match add_directory, which indexes the expanded path. diff --git a/tests/test_rag_server_directory_nonstring.py b/tests/test_rag_server_directory_nonstring.py new file mode 100644 index 0000000..4311cf5 --- /dev/null +++ b/tests/test_rag_server_directory_nonstring.py @@ -0,0 +1,28 @@ +"""Regression: rag_server add/remove_directory must not crash on a non-string path. + +`directory = arguments.get("directory", "").strip()` runs before the surrounding +try, so a non-string `directory` in the tool args (e.g. a number) raised +AttributeError out of call_tool. Coerce non-strings to "". +""" +import asyncio + +import pytest + +pytest.importorskip("mcp") + +import mcp_servers.rag_server as rs + + +def _call(monkeypatch, action, directory): + monkeypatch.setattr(rs, "_ensure_init", lambda: None) + return asyncio.run(rs.call_tool("manage_rag", {"action": action, "directory": directory})) + + +def test_add_directory_non_string_does_not_crash(monkeypatch): + out = _call(monkeypatch, "add_directory", 123) + assert "needs a directory path" in out[0].text + + +def test_remove_directory_non_string_does_not_crash(monkeypatch): + out = _call(monkeypatch, "remove_directory", ["x"]) + assert "needs a directory path" in out[0].text