diff --git a/core/database.py b/core/database.py index 4788a45..5c33422 100644 --- a/core/database.py +++ b/core/database.py @@ -375,6 +375,7 @@ class McpServer(TimestampMixin, Base): is_enabled = Column(Boolean, default=True) oauth_config = Column(Text, nullable=True) # JSON: provider, keys_file, token_file, scopes disabled_tools = Column(Text, nullable=True) # JSON array of tool names to hide from LLM + oauth_tokens = Column(EncryptedText, nullable=True) # JSON {tokens, client_info} for generic MCP OAuth, encrypted at rest class Comparison(TimestampMixin, Base): @@ -1311,6 +1312,23 @@ def _migrate_add_disabled_tools(): except Exception as e: logging.getLogger(__name__).warning(f"disabled_tools migration: {e}") +def _migrate_add_mcp_oauth_tokens_column(): + """Add oauth_tokens column to mcp_servers table if missing. + + The model declares this column as EncryptedText, but the SQL type is plain + TEXT on purpose: EncryptedText is a SQLAlchemy TypeDecorator that encrypts at + the Python layer and stores the ciphertext as TEXT, so the DB column type is + TEXT. This matches the existing encrypted columns (see _migrate_encrypt_*).""" + try: + with engine.connect() as conn: + cols = [r[1] for r in conn.execute(text("PRAGMA table_info(mcp_servers)"))] + if "oauth_tokens" not in cols: + conn.execute(text("ALTER TABLE mcp_servers ADD COLUMN oauth_tokens TEXT")) + conn.commit() + logging.getLogger(__name__).info("Added oauth_tokens column to mcp_servers") + except Exception as e: + logging.getLogger(__name__).warning(f"oauth_tokens migration: {e}") + def _migrate_add_task_v2_columns(): """Add cron_expression, then_task_id, webhook_token to scheduled_tasks.""" new_cols = { @@ -1589,6 +1607,7 @@ def init_db(): _migrate_add_oauth_config() _migrate_add_task_automation_columns() _migrate_add_disabled_tools() + _migrate_add_mcp_oauth_tokens_column() _migrate_add_task_v2_columns() _migrate_add_notifications_enabled() _migrate_drop_ping_notes_tasks() diff --git a/routes/mcp_routes.py b/routes/mcp_routes.py index 003559a..e3a73c8 100644 --- a/routes/mcp_routes.py +++ b/routes/mcp_routes.py @@ -141,6 +141,7 @@ def setup_mcp_routes(mcp_manager: McpManager): "disabled_tool_count": len(disabled_list), "enabled_tool_count": max(0, total_tools - len(disabled_list)), "error": status.get("error"), + "auth_url": status.get("auth_url"), "has_oauth": oauth_cfg is not None, "needs_oauth": needs_oauth, }) @@ -171,6 +172,8 @@ def setup_mcp_routes(mcp_manager: McpManager): raise HTTPException(400, "command is required for stdio transport") if transport == "sse" and not url: raise HTTPException(400, "url is required for SSE transport") + if transport == "http" and not url: + raise HTTPException(400, "url is required for HTTP transport") # Parse JSON fields try: @@ -262,6 +265,7 @@ def setup_mcp_routes(mcp_manager: McpManager): ) status = mcp_manager.get_server_status(server_id) + needs_auth = status.get("status") == "needs_auth" return { "id": server_id, "name": name, @@ -270,6 +274,8 @@ def setup_mcp_routes(mcp_manager: McpManager): "tool_count": status.get("tool_count", 0), "error": "OAuth authorization required" if needs_oauth else status.get("error"), "needs_oauth": needs_oauth, + "needs_auth": needs_auth, + "auth_url": status.get("auth_url"), } @router.post("/servers/{server_id}/reconnect") @@ -302,6 +308,8 @@ def setup_mcp_routes(mcp_manager: McpManager): "status": status.get("status", "disconnected"), "tool_count": status.get("tool_count", 0), "error": status.get("error"), + "auth_url": status.get("auth_url"), + "needs_auth": status.get("status") == "needs_auth", } finally: db.close() @@ -467,10 +475,18 @@ def setup_mcp_routes(mcp_manager: McpManager): @router.get("/oauth/callback") async def oauth_callback(code: str, state: str, request: Request): - """Handle OAuth callback from Google — exchange code for tokens.""" + """Handle OAuth callback. Generic MCP OAuth flows resolve via the + pending-state registry; Google flows fall through to the legacy path.""" require_admin(request) - server_id = state - return await _exchange_and_connect(server_id, code, request) + from src.mcp_oauth import resolve_pending + if resolve_pending(state, code): + return HTMLResponse(_oauth_result_page( + "Authorization Successful", + "The MCP server is connecting. You can close this window and return to Odysseus.", + success=True, + )) + # Legacy Google path: state is the server_id + return await _exchange_and_connect(state, code, request) @router.post("/oauth/exchange/{server_id}") async def oauth_exchange(server_id: str, request: Request, callback_url: str = Form(...)): @@ -485,6 +501,17 @@ def setup_mcp_routes(mcp_manager: McpManager): except Exception: return HTMLResponse(_oauth_result_page("Error", "Invalid URL format."), status_code=400) + # Generic MCP OAuth: if the pasted URL carries a state we are waiting on, + # resolve it directly (the background connect finishes the handshake). + state = params.get("state", [None])[0] + from src.mcp_oauth import resolve_pending + if state and resolve_pending(state, code): + return HTMLResponse(_oauth_result_page( + "Authorization Successful", + "The MCP server is connecting. You can close this window and return to Odysseus.", + success=True, + )) + return await _exchange_and_connect(server_id, code, request) async def _exchange_and_connect(server_id: str, code: str, request: Request): diff --git a/src/mcp_manager.py b/src/mcp_manager.py index 7cd9740..474e273 100644 --- a/src/mcp_manager.py +++ b/src/mcp_manager.py @@ -70,7 +70,9 @@ class McpManager: self._sessions: Dict[str, Any] = {} # server_id -> exit stack (for cleanup) self._stacks: Dict[str, Any] = {} - # Tracking updates to tools/connections for RAG indexing + # server_id -> background connect task (HTTP transport / OAuth) + self._connect_tasks: Dict[str, Any] = {} + # Tracking updates to tools/connections for RAG indexing / prompt cache self._generation = 0 async def connect_server( @@ -83,12 +85,14 @@ class McpManager: env: Optional[Dict[str, str]] = None, url: Optional[str] = None, ) -> bool: - """Connect to an MCP server via stdio or SSE transport.""" + """Connect to an MCP server via stdio, SSE, or Streamable HTTP transport.""" try: if transport == "stdio": res = await self._connect_stdio(server_id, name, command, args or [], env or {}) elif transport == "sse": res = await self._connect_sse(server_id, name, url) + elif transport == "http": + res = await self._start_http_connect(server_id, name, url) else: logger.error(f"Unknown MCP transport: {transport}") res = False @@ -211,8 +215,101 @@ class McpManager: self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name} return False + async def _start_http_connect(self, server_id: str, name: str, url: str, wait: float = 8.0) -> bool: + """Begin a Streamable HTTP connect in the background. Returns within + `wait` seconds: True if it connected (cached-token path), otherwise the + flow is awaiting browser authorization and status becomes 'needs_auth'.""" + import asyncio + self._connections[server_id] = {"status": "connecting", "name": name, "transport": "http"} + task = asyncio.create_task(self._connect_http(server_id, name, url)) + self._connect_tasks[server_id] = task + done, _ = await asyncio.wait({task}, timeout=wait) + if task in done: + try: + return task.result() + except Exception as e: + self._connections[server_id] = {"status": "error", "error": str(e), "name": name} + return False + # Still running → either awaiting authorization, or discovery/DCR is + # still in flight. If _on_redirect already published needs_auth+auth_url, + # leave it; otherwise mark needs_auth (auth_url filled in once it fires). + from src.mcp_oauth import pop_auth_url + cur = self._connections.get(server_id, {}) + if cur.get("status") != "needs_auth": + self._connections[server_id] = { + "status": "needs_auth", "name": name, "transport": "http", + "auth_url": pop_auth_url(server_id), + } + return False + + async def _connect_http(self, server_id: str, name: str, url: str) -> bool: + """Connect to a Streamable HTTP MCP server (with automatic OAuth).""" + try: + from mcp import ClientSession + from mcp.client.streamable_http import streamablehttp_client + from contextlib import AsyncExitStack + from src.mcp_oauth import build_provider, clear_auth_url + + def _on_redirect(auth_url): + # Publish needs_auth the moment the URL is known, independent of + # how long discovery/DCR took (may exceed the bounded start wait). + self._connections[server_id] = { + "status": "needs_auth", "name": name, "transport": "http", + "auth_url": auth_url, + } + + provider = build_provider(server_id, url, on_redirect=_on_redirect) + stack = AsyncExitStack() + transport = await stack.enter_async_context(streamablehttp_client(url, auth=provider)) + read_stream, write_stream, _get_session_id = transport + session = await stack.enter_async_context(ClientSession(read_stream, write_stream)) + await session.initialize() + + tools_result = await session.list_tools() + tools = [] + for tool in tools_result.tools: + tools.append({ + "name": tool.name, + "description": tool.description or "", + "input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {}, + }) + + self._sessions[server_id] = session + self._stacks[server_id] = stack + self._tools[server_id] = tools + self._connections[server_id] = { + "status": "connected", "name": name, "transport": "http", + "tool_count": len(tools), + } + clear_auth_url(server_id) + # Tools changed (this can complete after connect_server already + # returned, via the background OAuth flow), so bump the generation + # to invalidate the tool-prompt cache. + self._generation += 1 + logger.info(f"MCP server connected: {name} ({server_id}) - {len(tools)} tools via http") + return True + except ImportError: + logger.warning("MCP package not installed. Install with: pip install mcp") + self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name} + return False + except Exception as e: + logger.error(f"Failed to connect HTTP MCP server {name} ({server_id}): {e}") + self._connections[server_id] = {"status": "error", "error": str(e), "name": name} + return False + async def disconnect_server(self, server_id: str): """Disconnect from an MCP server.""" + # Cancel any in-flight HTTP/OAuth background connect so it stops + # publishing status for a server that may be getting deleted. + task = self._connect_tasks.pop(server_id, None) + if task is not None and not task.done(): + task.cancel() + try: + from src.mcp_oauth import clear_auth_url + clear_auth_url(server_id) + except Exception: + pass + stack = self._stacks.pop(server_id, None) if stack: try: diff --git a/src/mcp_oauth.py b/src/mcp_oauth.py new file mode 100644 index 0000000..9f3b2ad --- /dev/null +++ b/src/mcp_oauth.py @@ -0,0 +1,193 @@ +"""mcp_oauth.py — generic OAuth for remote (Streamable HTTP) MCP servers. + +Bridges the mcp SDK's OAuthClientProvider (RFC 9728 discovery, Dynamic Client +Registration, authorization-code + PKCE, token refresh) to Odysseus's web +callback route. Tokens and the dynamic registration persist per-server, +encrypted, so the interactive flow runs only once. +""" +import asyncio +import json +import logging +import os +import time +from typing import Dict, Optional, Tuple +from urllib.parse import urlparse, parse_qs + +logger = logging.getLogger(__name__) + +# OAuth redirect URI registered with every authorization server via DCR. Loopback +# is allowed for native/desktop clients (RFC 8252); remote users finish via the +# paste-back flow. Deployments not reachable at http://localhost:7000 (custom +# port, reverse proxy, or public domain) must set OAUTH_REDIRECT_BASE_URL (or +# APP_PUBLIC_URL) to their externally reachable origin so the redirect lands back +# on Odysseus. APP_PORT is intentionally not used: it is only the Docker host +# port-map; the app always listens on 7000 inside the container. +_REDIRECT_BASE = ( + os.environ.get("OAUTH_REDIRECT_BASE_URL") + or os.environ.get("APP_PUBLIC_URL") + or "http://localhost:7000" +).rstrip("/") +REDIRECT_URI = f"{_REDIRECT_BASE}/api/mcp/oauth/callback" + +# How long the background connect waits for the user to authorize before giving up. +AUTH_WAIT_SECONDS = 300 + +_pending: Dict[str, asyncio.Future] = {} # state -> Future[(code, state)] +_pending_ts: Dict[str, float] = {} # state -> monotonic timestamp, for pruning +_auth_urls: Dict[str, str] = {} # server_id -> authorization URL + + +def _prune_stale() -> None: + """Drop abandoned flows whose authorization window has elapsed so the + module-level registries don't grow unbounded (e.g. a user who never + finishes the browser step).""" + now = time.monotonic() + for state in [s for s, ts in _pending_ts.items() if now - ts > AUTH_WAIT_SECONDS]: + fut = _pending.pop(state, None) + _pending_ts.pop(state, None) + if fut is not None and not fut.done(): + fut.cancel() + + +def _discard_pending(state: Optional[str]) -> None: + if state is None: + return + _pending.pop(state, None) + _pending_ts.pop(state, None) + + +def register_pending(state: str) -> asyncio.Future: + _prune_stale() + fut = asyncio.get_running_loop().create_future() + _pending[state] = fut + _pending_ts[state] = time.monotonic() + return fut + + +def resolve_pending(state: str, code: str) -> bool: + fut = _pending.get(state) + if fut is not None and not fut.done(): + fut.set_result((code, state)) + return True + return False + + +def pop_auth_url(server_id: str) -> Optional[str]: + return _auth_urls.get(server_id) + + +def clear_auth_url(server_id: str) -> None: + _auth_urls.pop(server_id, None) + + +class DbTokenStorage: + """SDK TokenStorage backed by the encrypted McpServer.oauth_tokens column.""" + + def __init__(self, server_id: str, session_factory=None): + self.server_id = server_id + if session_factory is None: + from core.database import SessionLocal + session_factory = SessionLocal + self._sf = session_factory + + def _load(self) -> dict: + from core.database import McpServer + db = self._sf() + try: + srv = db.query(McpServer).filter(McpServer.id == self.server_id).first() + if srv and srv.oauth_tokens: + return json.loads(srv.oauth_tokens) + finally: + db.close() + return {} + + def _update(self, key: str, value: dict) -> None: + """Load, set one key, and persist the oauth_tokens JSON in a single + session/commit (avoids the load+save double round-trip per write).""" + from core.database import McpServer + db = self._sf() + try: + srv = db.query(McpServer).filter(McpServer.id == self.server_id).first() + if srv is None: + return + data = json.loads(srv.oauth_tokens) if srv.oauth_tokens else {} + data[key] = value + srv.oauth_tokens = json.dumps(data) + db.commit() + finally: + db.close() + + async def get_tokens(self): + from mcp.shared.auth import OAuthToken + data = self._load().get("tokens") + return OAuthToken.model_validate(data) if data else None + + async def set_tokens(self, tokens) -> None: + self._update("tokens", json.loads(tokens.model_dump_json())) + + async def get_client_info(self): + from mcp.shared.auth import OAuthClientInformationFull + data = self._load().get("client_info") + return OAuthClientInformationFull.model_validate(data) if data else None + + async def set_client_info(self, client_info) -> None: + self._update("client_info", json.loads(client_info.model_dump_json())) + + +def build_provider(server_id: str, url: str, on_redirect=None): + """Construct an OAuthClientProvider that drives the browser flow via the + Odysseus callback route. + + on_redirect(authorization_url): optional sync callback invoked the moment + the authorization URL is known (after discovery + DCR). The manager uses it + to publish 'needs_auth' + auth_url to connection state regardless of how + long discovery/DCR took. + """ + from mcp.client.auth import OAuthClientProvider + from mcp.shared.auth import OAuthClientMetadata + + client_metadata = OAuthClientMetadata( + client_name="Odysseus", + redirect_uris=[REDIRECT_URI], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + # Leave scope unset: the SDK applies the MCP scope-selection strategy and + # overwrites this from the server's WWW-Authenticate / protected-resource + # metadata before building the auth URL. Hardcoding an OIDC scope here + # would break the many MCP servers that are not OpenID providers. + scope=None, + token_endpoint_auth_method="none", + ) + + async def redirect_handler(authorization_url: str) -> None: + state = (parse_qs(urlparse(authorization_url).query).get("state") or [None])[0] + if state: + register_pending(state) + _auth_urls[server_id] = authorization_url + if on_redirect is not None: + try: + on_redirect(authorization_url) + except Exception as e: + logger.warning(f"MCP OAuth on_redirect callback failed: {e}") + logger.info(f"MCP OAuth: server {server_id} awaiting authorization (state={state})") + + async def callback_handler() -> Tuple[str, Optional[str]]: + auth_url = _auth_urls.get(server_id) + state = (parse_qs(urlparse(auth_url).query).get("state") or [None])[0] if auth_url else None + fut = _pending.get(state) + if fut is None: + raise RuntimeError("No pending OAuth flow for this server") + try: + code, ret_state = await asyncio.wait_for(fut, timeout=AUTH_WAIT_SECONDS) + return code, ret_state + finally: + _discard_pending(state) + _auth_urls.pop(server_id, None) + + return OAuthClientProvider( + server_url=url, + client_metadata=client_metadata, + storage=DbTokenStorage(server_id), + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) diff --git a/static/js/settings.js b/static/js/settings.js index 8a53606..3a6e9d0 100644 --- a/static/js/settings.js +++ b/static/js/settings.js @@ -4448,6 +4448,68 @@ async function initUnifiedIntegrations() { // ── MCP form — full management view ── async function showMcpForm(editId) { + // Toggle an in-flight loading state on a button (disabled + dimmed + label). + function _setBtnLoading(btn, loading, label) { + if (!btn) return; + btn.disabled = loading; + btn.style.opacity = loading ? '0.6' : ''; + btn.style.cursor = loading ? 'progress' : ''; + if (label != null) btn.textContent = label; + } + function _showMcpPasteback(id) { + const msg = el('uf-mcp-msg'); if (!msg) return; + if (el('uf-mcp-pasteback')) return; // already shown + msg.innerHTML = + 'Authorize in the opened tab. If the redirect fails (remote access), paste the resulting URL here: ' + + '' + + ''; + const pasteGo = el('uf-mcp-paste-go'); + if (pasteGo) pasteGo.addEventListener('click', async () => { + const cb = el('uf-mcp-pasteback').value.trim(); + if (!cb) return; + const pf = new FormData(); pf.append('callback_url', cb); + _setBtnLoading(pasteGo, true, 'Submitting…'); + try { + await fetch(`/api/mcp/oauth/exchange/${id}`, { method: 'POST', credentials: 'same-origin', body: pf }); + } finally { + _setBtnLoading(pasteGo, false, 'Submit'); + } + }); + } + + // Drives the OAuth flow: waits for the auth_url (discovery+DCR may lag), + // opens it once, then resolves on connected/error. + async function _handleMcpAuth(id, initialAuthUrl, tries = 90) { + let opened = false; + const openAuth = (u) => { if (!opened && u) { opened = true; window.open(u, '_blank', 'noopener'); _showMcpPasteback(id); } }; + openAuth(initialAuthUrl); + const msg = el('uf-mcp-msg'); + let fails = 0; + for (let i = 0; i < tries; i++) { + await new Promise(res => setTimeout(res, 2000)); + try { + const r = await fetch('/api/mcp/servers', { credentials: 'same-origin' }); + if (!r.ok) throw new Error('HTTP ' + r.status); + const list = await r.json(); + fails = 0; + const s = Array.isArray(list) ? list.find(x => x.id === id) : null; + if (!s) continue; + if (s.auth_url) openAuth(s.auth_url); + if (s.status === 'connected') { + if (msg) msg.textContent = `Connected (${s.tool_count || 0} tools)`; + await renderList(); return; + } + if (s.status === 'error') { + if (msg) msg.textContent = `Failed: ${s.error || 'unknown'}`; return; + } + } catch (e) { + // Tolerate a single blip, but surface persistent failures instead of + // silently polling until timeout. + if (++fails >= 5 && msg) msg.textContent = `Status check failing (${e.message || 'network error'}) — still retrying…`; + } + } + if (msg) msg.textContent = 'Authorization timed out. Reconnect from the server list to retry.'; + } if (editId && editId !== 'new') { // Show management view for existing server formEl.innerHTML = '
Loading...
'; @@ -4525,7 +4587,7 @@ async function initUnifiedIntegrations() {

Add MCP Server

-
+
@@ -4538,9 +4600,12 @@ async function initUnifiedIntegrations() {
`; el('uf-mcp-transport').addEventListener('change', () => { - const sse = el('uf-mcp-transport').value === 'sse'; - el('uf-mcp-stdio-fields').style.display = sse ? 'none' : 'flex'; - el('uf-mcp-sse-fields').style.display = sse ? 'flex' : 'none'; + const v = el('uf-mcp-transport').value; + const isUrl = (v === 'sse' || v === 'http'); + el('uf-mcp-stdio-fields').style.display = isUrl ? 'none' : 'flex'; + el('uf-mcp-sse-fields').style.display = isUrl ? 'flex' : 'none'; + const urlInput = el('uf-mcp-url'); + if (urlInput) urlInput.placeholder = (v === 'http') ? 'https://mcp.example.com/mcp' : 'http://localhost:3001/sse'; }); el('uf-mcp-cancel').addEventListener('click', () => { formEl.style.display = 'none'; }); el('uf-mcp-save').addEventListener('click', async () => { @@ -4558,14 +4623,25 @@ async function initUnifiedIntegrations() { } else { fd.append('url', el('uf-mcp-url').value); } + const saveBtn = el('uf-mcp-save'), cancelBtn = el('uf-mcp-cancel'); + const _origLabel = saveBtn.textContent; + _setBtnLoading(saveBtn, true, 'Saving…'); if (cancelBtn) cancelBtn.disabled = true; try { const r = await fetch('/api/mcp/servers', { method: 'POST', credentials: 'same-origin', body: fd }); - if (r.ok) { + const data = await r.json().catch(() => ({})); + if (r.ok && data.needs_auth) { + el('uf-mcp-msg').textContent = 'Preparing authorization…'; + _handleMcpAuth(data.id, data.auth_url); + } else if (r.ok && (data.connected || data.status === 'connected')) { + el('uf-mcp-msg').textContent = `Connected (${data.tool_count || 0} tools)`; + formEl.style.display = 'none'; await renderList(); + } else if (r.ok) { el('uf-mcp-msg').textContent = 'Saved'; formEl.style.display = 'none'; await renderList(); } else { el('uf-mcp-msg').textContent = `Failed (${r.status})`; } } catch (_) { el('uf-mcp-msg').textContent = 'Failed'; } + finally { _setBtnLoading(saveBtn, false, _origLabel); if (cancelBtn) cancelBtn.disabled = false; } }); } } diff --git a/tests/test_mcp_manager.py b/tests/test_mcp_manager.py index 20a3bc3..a879f95 100644 --- a/tests/test_mcp_manager.py +++ b/tests/test_mcp_manager.py @@ -1,4 +1,7 @@ -from src.mcp_manager import _format_mcp_connection_error +import asyncio +from unittest.mock import patch + +from src.mcp_manager import _format_mcp_connection_error, McpManager def test_playwright_mcp_connection_error_includes_install_hint(): @@ -24,3 +27,15 @@ def test_generic_mcp_connection_error_preserves_original_error(): ) assert msg == "boom" + + +def test_http_transport_routes_to_start_http_connect(): + mgr = McpManager() + + async def fake_start(server_id, name, url): + return "ROUTED" + + with patch.object(McpManager, "_start_http_connect", side_effect=fake_start) as m: + result = asyncio.run(mgr.connect_server("id1", "n", "http", url="https://x/mcp")) + assert result == "ROUTED" + m.assert_called_once() diff --git a/tests/test_mcp_oauth.py b/tests/test_mcp_oauth.py new file mode 100644 index 0000000..a9f5fdf --- /dev/null +++ b/tests/test_mcp_oauth.py @@ -0,0 +1,81 @@ +import asyncio +from src import mcp_oauth + + +def test_registry_resolve_returns_code_and_state(): + async def go(): + fut = mcp_oauth.register_pending("st-1") + assert mcp_oauth.resolve_pending("st-1", "the-code") is True + return await asyncio.wait_for(fut, timeout=1) + code, state = asyncio.run(go()) + assert code == "the-code" + assert state == "st-1" + + +def test_resolve_unknown_state_is_false(): + assert mcp_oauth.resolve_pending("nope", "x") is False + + +def test_register_pending_prunes_abandoned_flows(): + import time as _t + + async def go(): + mcp_oauth._pending.clear() + mcp_oauth._pending_ts.clear() + old = mcp_oauth.register_pending("old-state") + # Backdate the entry past the authorization window. + mcp_oauth._pending_ts["old-state"] = _t.monotonic() - (mcp_oauth.AUTH_WAIT_SECONDS + 1) + # A new registration triggers a prune of the stale one. + mcp_oauth.register_pending("new-state") + return old + + old = asyncio.run(go()) + assert "old-state" not in mcp_oauth._pending + assert "old-state" not in mcp_oauth._pending_ts + assert "new-state" in mcp_oauth._pending + assert old.cancelled() + + +def test_build_provider_has_odysseus_client_metadata(): + p = mcp_oauth.build_provider("srv-1", "https://example.com/mcp") + md = p.context.client_metadata + assert md.client_name == "Odysseus" + assert "authorization_code" in md.grant_types + assert "refresh_token" in md.grant_types + assert str(md.redirect_uris[0]).rstrip("/") == mcp_oauth.REDIRECT_URI.rstrip("/") + + +def test_db_token_storage_round_trip(): + from mcp.shared.auth import OAuthToken + + class FakeSrv: + oauth_tokens = None + + srv = FakeSrv() + + class FakeQuery: + def filter(self, *a): + return self + + def first(self): + return srv + + class FakeSession: + def query(self, *a): + return FakeQuery() + + def commit(self): + pass + + def close(self): + pass + + storage = mcp_oauth.DbTokenStorage("srv-1", session_factory=lambda: FakeSession()) + + async def go(): + await storage.set_tokens(OAuthToken(access_token="abc", token_type="Bearer")) + return await storage.get_tokens() + + t = asyncio.run(go()) + assert t.access_token == "abc" + assert srv.oauth_tokens is not None # persisted as JSON