feat(mcp): add Streamable HTTP transport with OAuth 2.0 (#1033)
* feat(mcp): add Streamable HTTP transport with OAuth 2.0 Odysseus could only reach MCP servers over stdio and SSE, so modern remote servers like https://mcp.higgsfield.ai/mcp (Streamable HTTP, gated behind OAuth) could not be connected. Add an `http` transport that connects via the SDK's streamablehttp_client and authenticates with the SDK's OAuthClientProvider: RFC 9728 protected-resource discovery, RFC 8414 authorization-server metadata, Dynamic Client Registration, authorization-code + PKCE, and token refresh. A small bridge (src/mcp_oauth.py) connects the SDK's blocking callback to the existing web callback route via an asyncio.Future keyed by the OAuth `state`, and the dynamic client registration plus tokens persist per-server in a new encrypted `oauth_tokens` column. The connect runs as a bounded background task so the "Add server" request returns immediately; redirect_handler publishes needs_auth + auth_url to connection state as soon as discovery/DCR completes (which can exceed the bounded wait), and the UI polls until connected. Remote users finish via the existing paste-back flow. The Google OAuth path is left unchanged. - core/database.py: encrypted oauth_tokens column + migration - src/mcp_oauth.py: OAuth provider, DB-backed TokenStorage, state registry - src/mcp_manager.py: http dispatch, background connect, _connect_http - routes/mcp_routes.py: http validation, needs_auth/auth_url, callback bridge - static/js/settings.js: Streamable HTTP option + OAuth flow with polling - tests: 5 new unit tests (transport dispatch, registry, token storage) Verified against the live Higgsfield server: discovery, DCR (client_id issued), loopback redirect accepted, and a PKCE authorization URL with needs_auth status. No regressions (full suite delta is only the 5 added passing tests). * fix(mcp): address PR #1033 review feedback - mcp_oauth: derive redirect URI from OAUTH_REDIRECT_BASE_URL/APP_PUBLIC_URL (default http://localhost:7000) instead of hardcoding the port - mcp_oauth: leave OAuth scope unset so the SDK derives it from the server's WWW-Authenticate/protected-resource metadata; hardcoding an OIDC scope broke non-OpenID MCP servers (verified: Higgsfield still gets its server-derived scope) - mcp_oauth: prune abandoned OAuth flows (_prune_stale + _pending_ts) so the module-level registries can't grow unbounded - mcp_oauth: persist tokens/client-info in a single DB session/commit (_update) instead of a load+save double round-trip - mcp_manager: cancel and drop the background connect task in disconnect_server so a deleted server stops publishing status - database: document why the oauth_tokens migration uses TEXT while the model declares EncryptedText (encryption is applied at the Python layer) - settings.js: surface persistent OAuth-poll failures and an explicit timeout message instead of silently swallowing errors - tests: cover the stale-flow pruning * static/js/settings.js now shows an in-flight loading state on the buttons that fire requests:
This commit is contained in:
committed by
GitHub
parent
85334e8f3d
commit
1d80bf5e65
@@ -375,6 +375,7 @@ class McpServer(TimestampMixin, Base):
|
|||||||
is_enabled = Column(Boolean, default=True)
|
is_enabled = Column(Boolean, default=True)
|
||||||
oauth_config = Column(Text, nullable=True) # JSON: provider, keys_file, token_file, scopes
|
oauth_config = Column(Text, nullable=True) # JSON: provider, keys_file, token_file, scopes
|
||||||
disabled_tools = Column(Text, nullable=True) # JSON array of tool names to hide from LLM
|
disabled_tools = Column(Text, nullable=True) # JSON array of tool names to hide from LLM
|
||||||
|
oauth_tokens = Column(EncryptedText, nullable=True) # JSON {tokens, client_info} for generic MCP OAuth, encrypted at rest
|
||||||
|
|
||||||
|
|
||||||
class Comparison(TimestampMixin, Base):
|
class Comparison(TimestampMixin, Base):
|
||||||
@@ -1311,6 +1312,23 @@ def _migrate_add_disabled_tools():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"disabled_tools migration: {e}")
|
logging.getLogger(__name__).warning(f"disabled_tools migration: {e}")
|
||||||
|
|
||||||
|
def _migrate_add_mcp_oauth_tokens_column():
|
||||||
|
"""Add oauth_tokens column to mcp_servers table if missing.
|
||||||
|
|
||||||
|
The model declares this column as EncryptedText, but the SQL type is plain
|
||||||
|
TEXT on purpose: EncryptedText is a SQLAlchemy TypeDecorator that encrypts at
|
||||||
|
the Python layer and stores the ciphertext as TEXT, so the DB column type is
|
||||||
|
TEXT. This matches the existing encrypted columns (see _migrate_encrypt_*)."""
|
||||||
|
try:
|
||||||
|
with engine.connect() as conn:
|
||||||
|
cols = [r[1] for r in conn.execute(text("PRAGMA table_info(mcp_servers)"))]
|
||||||
|
if "oauth_tokens" not in cols:
|
||||||
|
conn.execute(text("ALTER TABLE mcp_servers ADD COLUMN oauth_tokens TEXT"))
|
||||||
|
conn.commit()
|
||||||
|
logging.getLogger(__name__).info("Added oauth_tokens column to mcp_servers")
|
||||||
|
except Exception as e:
|
||||||
|
logging.getLogger(__name__).warning(f"oauth_tokens migration: {e}")
|
||||||
|
|
||||||
def _migrate_add_task_v2_columns():
|
def _migrate_add_task_v2_columns():
|
||||||
"""Add cron_expression, then_task_id, webhook_token to scheduled_tasks."""
|
"""Add cron_expression, then_task_id, webhook_token to scheduled_tasks."""
|
||||||
new_cols = {
|
new_cols = {
|
||||||
@@ -1589,6 +1607,7 @@ def init_db():
|
|||||||
_migrate_add_oauth_config()
|
_migrate_add_oauth_config()
|
||||||
_migrate_add_task_automation_columns()
|
_migrate_add_task_automation_columns()
|
||||||
_migrate_add_disabled_tools()
|
_migrate_add_disabled_tools()
|
||||||
|
_migrate_add_mcp_oauth_tokens_column()
|
||||||
_migrate_add_task_v2_columns()
|
_migrate_add_task_v2_columns()
|
||||||
_migrate_add_notifications_enabled()
|
_migrate_add_notifications_enabled()
|
||||||
_migrate_drop_ping_notes_tasks()
|
_migrate_drop_ping_notes_tasks()
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
|||||||
"disabled_tool_count": len(disabled_list),
|
"disabled_tool_count": len(disabled_list),
|
||||||
"enabled_tool_count": max(0, total_tools - len(disabled_list)),
|
"enabled_tool_count": max(0, total_tools - len(disabled_list)),
|
||||||
"error": status.get("error"),
|
"error": status.get("error"),
|
||||||
|
"auth_url": status.get("auth_url"),
|
||||||
"has_oauth": oauth_cfg is not None,
|
"has_oauth": oauth_cfg is not None,
|
||||||
"needs_oauth": needs_oauth,
|
"needs_oauth": needs_oauth,
|
||||||
})
|
})
|
||||||
@@ -171,6 +172,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
|||||||
raise HTTPException(400, "command is required for stdio transport")
|
raise HTTPException(400, "command is required for stdio transport")
|
||||||
if transport == "sse" and not url:
|
if transport == "sse" and not url:
|
||||||
raise HTTPException(400, "url is required for SSE transport")
|
raise HTTPException(400, "url is required for SSE transport")
|
||||||
|
if transport == "http" and not url:
|
||||||
|
raise HTTPException(400, "url is required for HTTP transport")
|
||||||
|
|
||||||
# Parse JSON fields
|
# Parse JSON fields
|
||||||
try:
|
try:
|
||||||
@@ -262,6 +265,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
|||||||
)
|
)
|
||||||
|
|
||||||
status = mcp_manager.get_server_status(server_id)
|
status = mcp_manager.get_server_status(server_id)
|
||||||
|
needs_auth = status.get("status") == "needs_auth"
|
||||||
return {
|
return {
|
||||||
"id": server_id,
|
"id": server_id,
|
||||||
"name": name,
|
"name": name,
|
||||||
@@ -270,6 +274,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
|||||||
"tool_count": status.get("tool_count", 0),
|
"tool_count": status.get("tool_count", 0),
|
||||||
"error": "OAuth authorization required" if needs_oauth else status.get("error"),
|
"error": "OAuth authorization required" if needs_oauth else status.get("error"),
|
||||||
"needs_oauth": needs_oauth,
|
"needs_oauth": needs_oauth,
|
||||||
|
"needs_auth": needs_auth,
|
||||||
|
"auth_url": status.get("auth_url"),
|
||||||
}
|
}
|
||||||
|
|
||||||
@router.post("/servers/{server_id}/reconnect")
|
@router.post("/servers/{server_id}/reconnect")
|
||||||
@@ -302,6 +308,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
|||||||
"status": status.get("status", "disconnected"),
|
"status": status.get("status", "disconnected"),
|
||||||
"tool_count": status.get("tool_count", 0),
|
"tool_count": status.get("tool_count", 0),
|
||||||
"error": status.get("error"),
|
"error": status.get("error"),
|
||||||
|
"auth_url": status.get("auth_url"),
|
||||||
|
"needs_auth": status.get("status") == "needs_auth",
|
||||||
}
|
}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -467,10 +475,18 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
|||||||
|
|
||||||
@router.get("/oauth/callback")
|
@router.get("/oauth/callback")
|
||||||
async def oauth_callback(code: str, state: str, request: Request):
|
async def oauth_callback(code: str, state: str, request: Request):
|
||||||
"""Handle OAuth callback from Google — exchange code for tokens."""
|
"""Handle OAuth callback. Generic MCP OAuth flows resolve via the
|
||||||
|
pending-state registry; Google flows fall through to the legacy path."""
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
server_id = state
|
from src.mcp_oauth import resolve_pending
|
||||||
return await _exchange_and_connect(server_id, code, request)
|
if resolve_pending(state, code):
|
||||||
|
return HTMLResponse(_oauth_result_page(
|
||||||
|
"Authorization Successful",
|
||||||
|
"The MCP server is connecting. You can close this window and return to Odysseus.",
|
||||||
|
success=True,
|
||||||
|
))
|
||||||
|
# Legacy Google path: state is the server_id
|
||||||
|
return await _exchange_and_connect(state, code, request)
|
||||||
|
|
||||||
@router.post("/oauth/exchange/{server_id}")
|
@router.post("/oauth/exchange/{server_id}")
|
||||||
async def oauth_exchange(server_id: str, request: Request, callback_url: str = Form(...)):
|
async def oauth_exchange(server_id: str, request: Request, callback_url: str = Form(...)):
|
||||||
@@ -485,6 +501,17 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
|||||||
except Exception:
|
except Exception:
|
||||||
return HTMLResponse(_oauth_result_page("Error", "Invalid URL format."), status_code=400)
|
return HTMLResponse(_oauth_result_page("Error", "Invalid URL format."), status_code=400)
|
||||||
|
|
||||||
|
# Generic MCP OAuth: if the pasted URL carries a state we are waiting on,
|
||||||
|
# resolve it directly (the background connect finishes the handshake).
|
||||||
|
state = params.get("state", [None])[0]
|
||||||
|
from src.mcp_oauth import resolve_pending
|
||||||
|
if state and resolve_pending(state, code):
|
||||||
|
return HTMLResponse(_oauth_result_page(
|
||||||
|
"Authorization Successful",
|
||||||
|
"The MCP server is connecting. You can close this window and return to Odysseus.",
|
||||||
|
success=True,
|
||||||
|
))
|
||||||
|
|
||||||
return await _exchange_and_connect(server_id, code, request)
|
return await _exchange_and_connect(server_id, code, request)
|
||||||
|
|
||||||
async def _exchange_and_connect(server_id: str, code: str, request: Request):
|
async def _exchange_and_connect(server_id: str, code: str, request: Request):
|
||||||
|
|||||||
@@ -70,7 +70,9 @@ class McpManager:
|
|||||||
self._sessions: Dict[str, Any] = {}
|
self._sessions: Dict[str, Any] = {}
|
||||||
# server_id -> exit stack (for cleanup)
|
# server_id -> exit stack (for cleanup)
|
||||||
self._stacks: Dict[str, Any] = {}
|
self._stacks: Dict[str, Any] = {}
|
||||||
# Tracking updates to tools/connections for RAG indexing
|
# server_id -> background connect task (HTTP transport / OAuth)
|
||||||
|
self._connect_tasks: Dict[str, Any] = {}
|
||||||
|
# Tracking updates to tools/connections for RAG indexing / prompt cache
|
||||||
self._generation = 0
|
self._generation = 0
|
||||||
|
|
||||||
async def connect_server(
|
async def connect_server(
|
||||||
@@ -83,12 +85,14 @@ class McpManager:
|
|||||||
env: Optional[Dict[str, str]] = None,
|
env: Optional[Dict[str, str]] = None,
|
||||||
url: Optional[str] = None,
|
url: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Connect to an MCP server via stdio or SSE transport."""
|
"""Connect to an MCP server via stdio, SSE, or Streamable HTTP transport."""
|
||||||
try:
|
try:
|
||||||
if transport == "stdio":
|
if transport == "stdio":
|
||||||
res = await self._connect_stdio(server_id, name, command, args or [], env or {})
|
res = await self._connect_stdio(server_id, name, command, args or [], env or {})
|
||||||
elif transport == "sse":
|
elif transport == "sse":
|
||||||
res = await self._connect_sse(server_id, name, url)
|
res = await self._connect_sse(server_id, name, url)
|
||||||
|
elif transport == "http":
|
||||||
|
res = await self._start_http_connect(server_id, name, url)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown MCP transport: {transport}")
|
logger.error(f"Unknown MCP transport: {transport}")
|
||||||
res = False
|
res = False
|
||||||
@@ -211,8 +215,101 @@ class McpManager:
|
|||||||
self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
|
self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _start_http_connect(self, server_id: str, name: str, url: str, wait: float = 8.0) -> bool:
|
||||||
|
"""Begin a Streamable HTTP connect in the background. Returns within
|
||||||
|
`wait` seconds: True if it connected (cached-token path), otherwise the
|
||||||
|
flow is awaiting browser authorization and status becomes 'needs_auth'."""
|
||||||
|
import asyncio
|
||||||
|
self._connections[server_id] = {"status": "connecting", "name": name, "transport": "http"}
|
||||||
|
task = asyncio.create_task(self._connect_http(server_id, name, url))
|
||||||
|
self._connect_tasks[server_id] = task
|
||||||
|
done, _ = await asyncio.wait({task}, timeout=wait)
|
||||||
|
if task in done:
|
||||||
|
try:
|
||||||
|
return task.result()
|
||||||
|
except Exception as e:
|
||||||
|
self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
|
||||||
|
return False
|
||||||
|
# Still running → either awaiting authorization, or discovery/DCR is
|
||||||
|
# still in flight. If _on_redirect already published needs_auth+auth_url,
|
||||||
|
# leave it; otherwise mark needs_auth (auth_url filled in once it fires).
|
||||||
|
from src.mcp_oauth import pop_auth_url
|
||||||
|
cur = self._connections.get(server_id, {})
|
||||||
|
if cur.get("status") != "needs_auth":
|
||||||
|
self._connections[server_id] = {
|
||||||
|
"status": "needs_auth", "name": name, "transport": "http",
|
||||||
|
"auth_url": pop_auth_url(server_id),
|
||||||
|
}
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _connect_http(self, server_id: str, name: str, url: str) -> bool:
|
||||||
|
"""Connect to a Streamable HTTP MCP server (with automatic OAuth)."""
|
||||||
|
try:
|
||||||
|
from mcp import ClientSession
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from src.mcp_oauth import build_provider, clear_auth_url
|
||||||
|
|
||||||
|
def _on_redirect(auth_url):
|
||||||
|
# Publish needs_auth the moment the URL is known, independent of
|
||||||
|
# how long discovery/DCR took (may exceed the bounded start wait).
|
||||||
|
self._connections[server_id] = {
|
||||||
|
"status": "needs_auth", "name": name, "transport": "http",
|
||||||
|
"auth_url": auth_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
provider = build_provider(server_id, url, on_redirect=_on_redirect)
|
||||||
|
stack = AsyncExitStack()
|
||||||
|
transport = await stack.enter_async_context(streamablehttp_client(url, auth=provider))
|
||||||
|
read_stream, write_stream, _get_session_id = transport
|
||||||
|
session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
|
||||||
|
await session.initialize()
|
||||||
|
|
||||||
|
tools_result = await session.list_tools()
|
||||||
|
tools = []
|
||||||
|
for tool in tools_result.tools:
|
||||||
|
tools.append({
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description or "",
|
||||||
|
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||||
|
})
|
||||||
|
|
||||||
|
self._sessions[server_id] = session
|
||||||
|
self._stacks[server_id] = stack
|
||||||
|
self._tools[server_id] = tools
|
||||||
|
self._connections[server_id] = {
|
||||||
|
"status": "connected", "name": name, "transport": "http",
|
||||||
|
"tool_count": len(tools),
|
||||||
|
}
|
||||||
|
clear_auth_url(server_id)
|
||||||
|
# Tools changed (this can complete after connect_server already
|
||||||
|
# returned, via the background OAuth flow), so bump the generation
|
||||||
|
# to invalidate the tool-prompt cache.
|
||||||
|
self._generation += 1
|
||||||
|
logger.info(f"MCP server connected: {name} ({server_id}) - {len(tools)} tools via http")
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("MCP package not installed. Install with: pip install mcp")
|
||||||
|
self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect HTTP MCP server {name} ({server_id}): {e}")
|
||||||
|
self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
|
||||||
|
return False
|
||||||
|
|
||||||
async def disconnect_server(self, server_id: str):
|
async def disconnect_server(self, server_id: str):
|
||||||
"""Disconnect from an MCP server."""
|
"""Disconnect from an MCP server."""
|
||||||
|
# Cancel any in-flight HTTP/OAuth background connect so it stops
|
||||||
|
# publishing status for a server that may be getting deleted.
|
||||||
|
task = self._connect_tasks.pop(server_id, None)
|
||||||
|
if task is not None and not task.done():
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
from src.mcp_oauth import clear_auth_url
|
||||||
|
clear_auth_url(server_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
stack = self._stacks.pop(server_id, None)
|
stack = self._stacks.pop(server_id, None)
|
||||||
if stack:
|
if stack:
|
||||||
try:
|
try:
|
||||||
|
|||||||
193
src/mcp_oauth.py
Normal file
193
src/mcp_oauth.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""mcp_oauth.py — generic OAuth for remote (Streamable HTTP) MCP servers.
|
||||||
|
|
||||||
|
Bridges the mcp SDK's OAuthClientProvider (RFC 9728 discovery, Dynamic Client
|
||||||
|
Registration, authorization-code + PKCE, token refresh) to Odysseus's web
|
||||||
|
callback route. Tokens and the dynamic registration persist per-server,
|
||||||
|
encrypted, so the interactive flow runs only once.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# OAuth redirect URI registered with every authorization server via DCR. Loopback
|
||||||
|
# is allowed for native/desktop clients (RFC 8252); remote users finish via the
|
||||||
|
# paste-back flow. Deployments not reachable at http://localhost:7000 (custom
|
||||||
|
# port, reverse proxy, or public domain) must set OAUTH_REDIRECT_BASE_URL (or
|
||||||
|
# APP_PUBLIC_URL) to their externally reachable origin so the redirect lands back
|
||||||
|
# on Odysseus. APP_PORT is intentionally not used: it is only the Docker host
|
||||||
|
# port-map; the app always listens on 7000 inside the container.
|
||||||
|
_REDIRECT_BASE = (
|
||||||
|
os.environ.get("OAUTH_REDIRECT_BASE_URL")
|
||||||
|
or os.environ.get("APP_PUBLIC_URL")
|
||||||
|
or "http://localhost:7000"
|
||||||
|
).rstrip("/")
|
||||||
|
REDIRECT_URI = f"{_REDIRECT_BASE}/api/mcp/oauth/callback"
|
||||||
|
|
||||||
|
# How long the background connect waits for the user to authorize before giving up.
|
||||||
|
AUTH_WAIT_SECONDS = 300
|
||||||
|
|
||||||
|
_pending: Dict[str, asyncio.Future] = {} # state -> Future[(code, state)]
|
||||||
|
_pending_ts: Dict[str, float] = {} # state -> monotonic timestamp, for pruning
|
||||||
|
_auth_urls: Dict[str, str] = {} # server_id -> authorization URL
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_stale() -> None:
|
||||||
|
"""Drop abandoned flows whose authorization window has elapsed so the
|
||||||
|
module-level registries don't grow unbounded (e.g. a user who never
|
||||||
|
finishes the browser step)."""
|
||||||
|
now = time.monotonic()
|
||||||
|
for state in [s for s, ts in _pending_ts.items() if now - ts > AUTH_WAIT_SECONDS]:
|
||||||
|
fut = _pending.pop(state, None)
|
||||||
|
_pending_ts.pop(state, None)
|
||||||
|
if fut is not None and not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
def _discard_pending(state: Optional[str]) -> None:
|
||||||
|
if state is None:
|
||||||
|
return
|
||||||
|
_pending.pop(state, None)
|
||||||
|
_pending_ts.pop(state, None)
|
||||||
|
|
||||||
|
|
||||||
|
def register_pending(state: str) -> asyncio.Future:
|
||||||
|
_prune_stale()
|
||||||
|
fut = asyncio.get_running_loop().create_future()
|
||||||
|
_pending[state] = fut
|
||||||
|
_pending_ts[state] = time.monotonic()
|
||||||
|
return fut
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_pending(state: str, code: str) -> bool:
|
||||||
|
fut = _pending.get(state)
|
||||||
|
if fut is not None and not fut.done():
|
||||||
|
fut.set_result((code, state))
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def pop_auth_url(server_id: str) -> Optional[str]:
|
||||||
|
return _auth_urls.get(server_id)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_auth_url(server_id: str) -> None:
|
||||||
|
_auth_urls.pop(server_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
class DbTokenStorage:
|
||||||
|
"""SDK TokenStorage backed by the encrypted McpServer.oauth_tokens column."""
|
||||||
|
|
||||||
|
def __init__(self, server_id: str, session_factory=None):
|
||||||
|
self.server_id = server_id
|
||||||
|
if session_factory is None:
|
||||||
|
from core.database import SessionLocal
|
||||||
|
session_factory = SessionLocal
|
||||||
|
self._sf = session_factory
|
||||||
|
|
||||||
|
def _load(self) -> dict:
|
||||||
|
from core.database import McpServer
|
||||||
|
db = self._sf()
|
||||||
|
try:
|
||||||
|
srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
|
||||||
|
if srv and srv.oauth_tokens:
|
||||||
|
return json.loads(srv.oauth_tokens)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _update(self, key: str, value: dict) -> None:
|
||||||
|
"""Load, set one key, and persist the oauth_tokens JSON in a single
|
||||||
|
session/commit (avoids the load+save double round-trip per write)."""
|
||||||
|
from core.database import McpServer
|
||||||
|
db = self._sf()
|
||||||
|
try:
|
||||||
|
srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
|
||||||
|
if srv is None:
|
||||||
|
return
|
||||||
|
data = json.loads(srv.oauth_tokens) if srv.oauth_tokens else {}
|
||||||
|
data[key] = value
|
||||||
|
srv.oauth_tokens = json.dumps(data)
|
||||||
|
db.commit()
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
async def get_tokens(self):
|
||||||
|
from mcp.shared.auth import OAuthToken
|
||||||
|
data = self._load().get("tokens")
|
||||||
|
return OAuthToken.model_validate(data) if data else None
|
||||||
|
|
||||||
|
async def set_tokens(self, tokens) -> None:
|
||||||
|
self._update("tokens", json.loads(tokens.model_dump_json()))
|
||||||
|
|
||||||
|
async def get_client_info(self):
|
||||||
|
from mcp.shared.auth import OAuthClientInformationFull
|
||||||
|
data = self._load().get("client_info")
|
||||||
|
return OAuthClientInformationFull.model_validate(data) if data else None
|
||||||
|
|
||||||
|
async def set_client_info(self, client_info) -> None:
|
||||||
|
self._update("client_info", json.loads(client_info.model_dump_json()))
|
||||||
|
|
||||||
|
|
||||||
|
def build_provider(server_id: str, url: str, on_redirect=None):
|
||||||
|
"""Construct an OAuthClientProvider that drives the browser flow via the
|
||||||
|
Odysseus callback route.
|
||||||
|
|
||||||
|
on_redirect(authorization_url): optional sync callback invoked the moment
|
||||||
|
the authorization URL is known (after discovery + DCR). The manager uses it
|
||||||
|
to publish 'needs_auth' + auth_url to connection state regardless of how
|
||||||
|
long discovery/DCR took.
|
||||||
|
"""
|
||||||
|
from mcp.client.auth import OAuthClientProvider
|
||||||
|
from mcp.shared.auth import OAuthClientMetadata
|
||||||
|
|
||||||
|
client_metadata = OAuthClientMetadata(
|
||||||
|
client_name="Odysseus",
|
||||||
|
redirect_uris=[REDIRECT_URI],
|
||||||
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
|
response_types=["code"],
|
||||||
|
# Leave scope unset: the SDK applies the MCP scope-selection strategy and
|
||||||
|
# overwrites this from the server's WWW-Authenticate / protected-resource
|
||||||
|
# metadata before building the auth URL. Hardcoding an OIDC scope here
|
||||||
|
# would break the many MCP servers that are not OpenID providers.
|
||||||
|
scope=None,
|
||||||
|
token_endpoint_auth_method="none",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def redirect_handler(authorization_url: str) -> None:
|
||||||
|
state = (parse_qs(urlparse(authorization_url).query).get("state") or [None])[0]
|
||||||
|
if state:
|
||||||
|
register_pending(state)
|
||||||
|
_auth_urls[server_id] = authorization_url
|
||||||
|
if on_redirect is not None:
|
||||||
|
try:
|
||||||
|
on_redirect(authorization_url)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"MCP OAuth on_redirect callback failed: {e}")
|
||||||
|
logger.info(f"MCP OAuth: server {server_id} awaiting authorization (state={state})")
|
||||||
|
|
||||||
|
async def callback_handler() -> Tuple[str, Optional[str]]:
|
||||||
|
auth_url = _auth_urls.get(server_id)
|
||||||
|
state = (parse_qs(urlparse(auth_url).query).get("state") or [None])[0] if auth_url else None
|
||||||
|
fut = _pending.get(state)
|
||||||
|
if fut is None:
|
||||||
|
raise RuntimeError("No pending OAuth flow for this server")
|
||||||
|
try:
|
||||||
|
code, ret_state = await asyncio.wait_for(fut, timeout=AUTH_WAIT_SECONDS)
|
||||||
|
return code, ret_state
|
||||||
|
finally:
|
||||||
|
_discard_pending(state)
|
||||||
|
_auth_urls.pop(server_id, None)
|
||||||
|
|
||||||
|
return OAuthClientProvider(
|
||||||
|
server_url=url,
|
||||||
|
client_metadata=client_metadata,
|
||||||
|
storage=DbTokenStorage(server_id),
|
||||||
|
redirect_handler=redirect_handler,
|
||||||
|
callback_handler=callback_handler,
|
||||||
|
)
|
||||||
@@ -4448,6 +4448,68 @@ async function initUnifiedIntegrations() {
|
|||||||
|
|
||||||
// ── MCP form — full management view ──
|
// ── MCP form — full management view ──
|
||||||
async function showMcpForm(editId) {
|
async function showMcpForm(editId) {
|
||||||
|
// Toggle an in-flight loading state on a button (disabled + dimmed + label).
|
||||||
|
function _setBtnLoading(btn, loading, label) {
|
||||||
|
if (!btn) return;
|
||||||
|
btn.disabled = loading;
|
||||||
|
btn.style.opacity = loading ? '0.6' : '';
|
||||||
|
btn.style.cursor = loading ? 'progress' : '';
|
||||||
|
if (label != null) btn.textContent = label;
|
||||||
|
}
|
||||||
|
function _showMcpPasteback(id) {
|
||||||
|
const msg = el('uf-mcp-msg'); if (!msg) return;
|
||||||
|
if (el('uf-mcp-pasteback')) return; // already shown
|
||||||
|
msg.innerHTML =
|
||||||
|
'Authorize in the opened tab. If the redirect fails (remote access), paste the resulting URL here: ' +
|
||||||
|
'<input id="uf-mcp-pasteback" class="settings-input" placeholder="http://localhost:7000/api/mcp/oauth/callback?code=..." style="margin-top:4px">' +
|
||||||
|
'<button class="admin-btn-sm" id="uf-mcp-paste-go" style="margin-top:4px">Submit</button>';
|
||||||
|
const pasteGo = el('uf-mcp-paste-go');
|
||||||
|
if (pasteGo) pasteGo.addEventListener('click', async () => {
|
||||||
|
const cb = el('uf-mcp-pasteback').value.trim();
|
||||||
|
if (!cb) return;
|
||||||
|
const pf = new FormData(); pf.append('callback_url', cb);
|
||||||
|
_setBtnLoading(pasteGo, true, 'Submitting…');
|
||||||
|
try {
|
||||||
|
await fetch(`/api/mcp/oauth/exchange/${id}`, { method: 'POST', credentials: 'same-origin', body: pf });
|
||||||
|
} finally {
|
||||||
|
_setBtnLoading(pasteGo, false, 'Submit');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drives the OAuth flow: waits for the auth_url (discovery+DCR may lag),
|
||||||
|
// opens it once, then resolves on connected/error.
|
||||||
|
async function _handleMcpAuth(id, initialAuthUrl, tries = 90) {
|
||||||
|
let opened = false;
|
||||||
|
const openAuth = (u) => { if (!opened && u) { opened = true; window.open(u, '_blank', 'noopener'); _showMcpPasteback(id); } };
|
||||||
|
openAuth(initialAuthUrl);
|
||||||
|
const msg = el('uf-mcp-msg');
|
||||||
|
let fails = 0;
|
||||||
|
for (let i = 0; i < tries; i++) {
|
||||||
|
await new Promise(res => setTimeout(res, 2000));
|
||||||
|
try {
|
||||||
|
const r = await fetch('/api/mcp/servers', { credentials: 'same-origin' });
|
||||||
|
if (!r.ok) throw new Error('HTTP ' + r.status);
|
||||||
|
const list = await r.json();
|
||||||
|
fails = 0;
|
||||||
|
const s = Array.isArray(list) ? list.find(x => x.id === id) : null;
|
||||||
|
if (!s) continue;
|
||||||
|
if (s.auth_url) openAuth(s.auth_url);
|
||||||
|
if (s.status === 'connected') {
|
||||||
|
if (msg) msg.textContent = `Connected (${s.tool_count || 0} tools)`;
|
||||||
|
await renderList(); return;
|
||||||
|
}
|
||||||
|
if (s.status === 'error') {
|
||||||
|
if (msg) msg.textContent = `Failed: ${s.error || 'unknown'}`; return;
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
// Tolerate a single blip, but surface persistent failures instead of
|
||||||
|
// silently polling until timeout.
|
||||||
|
if (++fails >= 5 && msg) msg.textContent = `Status check failing (${e.message || 'network error'}) — still retrying…`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (msg) msg.textContent = 'Authorization timed out. Reconnect from the server list to retry.';
|
||||||
|
}
|
||||||
if (editId && editId !== 'new') {
|
if (editId && editId !== 'new') {
|
||||||
// Show management view for existing server
|
// Show management view for existing server
|
||||||
formEl.innerHTML = '<div class="admin-card" style="margin-top:8px"><span style="opacity:0.5;font-size:11px">Loading...</span></div>';
|
formEl.innerHTML = '<div class="admin-card" style="margin-top:8px"><span style="opacity:0.5;font-size:11px">Loading...</span></div>';
|
||||||
@@ -4525,7 +4587,7 @@ async function initUnifiedIntegrations() {
|
|||||||
<h2 style="font-size:13px">Add MCP Server</h2>
|
<h2 style="font-size:13px">Add MCP Server</h2>
|
||||||
<div class="settings-col">
|
<div class="settings-col">
|
||||||
<div class="settings-row"><label class="settings-label">Name</label><input id="uf-mcp-name" class="settings-input" placeholder="Server name"></div>
|
<div class="settings-row"><label class="settings-label">Name</label><input id="uf-mcp-name" class="settings-input" placeholder="Server name"></div>
|
||||||
<div class="settings-row"><label class="settings-label">Transport</label><select id="uf-mcp-transport" class="settings-input"><option value="stdio">stdio</option><option value="sse">SSE</option></select></div>
|
<div class="settings-row"><label class="settings-label">Transport</label><select id="uf-mcp-transport" class="settings-input"><option value="stdio">stdio</option><option value="sse">SSE</option><option value="http">Streamable HTTP</option></select></div>
|
||||||
<div id="uf-mcp-stdio-fields" style="display:flex;flex-direction:column;gap:6px;">
|
<div id="uf-mcp-stdio-fields" style="display:flex;flex-direction:column;gap:6px;">
|
||||||
<div class="settings-row"><label class="settings-label">Command</label><input id="uf-mcp-cmd" class="settings-input" placeholder="npx"></div>
|
<div class="settings-row"><label class="settings-label">Command</label><input id="uf-mcp-cmd" class="settings-input" placeholder="npx"></div>
|
||||||
<div class="settings-row"><label class="settings-label">Args</label><input id="uf-mcp-args" class="settings-input" placeholder='["-y", "@modelcontextprotocol/server-filesystem"]'></div>
|
<div class="settings-row"><label class="settings-label">Args</label><input id="uf-mcp-args" class="settings-input" placeholder='["-y", "@modelcontextprotocol/server-filesystem"]'></div>
|
||||||
@@ -4538,9 +4600,12 @@ async function initUnifiedIntegrations() {
|
|||||||
</div>
|
</div>
|
||||||
</div>`;
|
</div>`;
|
||||||
el('uf-mcp-transport').addEventListener('change', () => {
|
el('uf-mcp-transport').addEventListener('change', () => {
|
||||||
const sse = el('uf-mcp-transport').value === 'sse';
|
const v = el('uf-mcp-transport').value;
|
||||||
el('uf-mcp-stdio-fields').style.display = sse ? 'none' : 'flex';
|
const isUrl = (v === 'sse' || v === 'http');
|
||||||
el('uf-mcp-sse-fields').style.display = sse ? 'flex' : 'none';
|
el('uf-mcp-stdio-fields').style.display = isUrl ? 'none' : 'flex';
|
||||||
|
el('uf-mcp-sse-fields').style.display = isUrl ? 'flex' : 'none';
|
||||||
|
const urlInput = el('uf-mcp-url');
|
||||||
|
if (urlInput) urlInput.placeholder = (v === 'http') ? 'https://mcp.example.com/mcp' : 'http://localhost:3001/sse';
|
||||||
});
|
});
|
||||||
el('uf-mcp-cancel').addEventListener('click', () => { formEl.style.display = 'none'; });
|
el('uf-mcp-cancel').addEventListener('click', () => { formEl.style.display = 'none'; });
|
||||||
el('uf-mcp-save').addEventListener('click', async () => {
|
el('uf-mcp-save').addEventListener('click', async () => {
|
||||||
@@ -4558,14 +4623,25 @@ async function initUnifiedIntegrations() {
|
|||||||
} else {
|
} else {
|
||||||
fd.append('url', el('uf-mcp-url').value);
|
fd.append('url', el('uf-mcp-url').value);
|
||||||
}
|
}
|
||||||
|
const saveBtn = el('uf-mcp-save'), cancelBtn = el('uf-mcp-cancel');
|
||||||
|
const _origLabel = saveBtn.textContent;
|
||||||
|
_setBtnLoading(saveBtn, true, 'Saving…'); if (cancelBtn) cancelBtn.disabled = true;
|
||||||
try {
|
try {
|
||||||
const r = await fetch('/api/mcp/servers', { method: 'POST', credentials: 'same-origin', body: fd });
|
const r = await fetch('/api/mcp/servers', { method: 'POST', credentials: 'same-origin', body: fd });
|
||||||
if (r.ok) {
|
const data = await r.json().catch(() => ({}));
|
||||||
|
if (r.ok && data.needs_auth) {
|
||||||
|
el('uf-mcp-msg').textContent = 'Preparing authorization…';
|
||||||
|
_handleMcpAuth(data.id, data.auth_url);
|
||||||
|
} else if (r.ok && (data.connected || data.status === 'connected')) {
|
||||||
|
el('uf-mcp-msg').textContent = `Connected (${data.tool_count || 0} tools)`;
|
||||||
|
formEl.style.display = 'none'; await renderList();
|
||||||
|
} else if (r.ok) {
|
||||||
el('uf-mcp-msg').textContent = 'Saved'; formEl.style.display = 'none'; await renderList();
|
el('uf-mcp-msg').textContent = 'Saved'; formEl.style.display = 'none'; await renderList();
|
||||||
} else {
|
} else {
|
||||||
el('uf-mcp-msg').textContent = `Failed (${r.status})`;
|
el('uf-mcp-msg').textContent = `Failed (${r.status})`;
|
||||||
}
|
}
|
||||||
} catch (_) { el('uf-mcp-msg').textContent = 'Failed'; }
|
} catch (_) { el('uf-mcp-msg').textContent = 'Failed'; }
|
||||||
|
finally { _setBtnLoading(saveBtn, false, _origLabel); if (cancelBtn) cancelBtn.disabled = false; }
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
from src.mcp_manager import _format_mcp_connection_error
|
import asyncio
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from src.mcp_manager import _format_mcp_connection_error, McpManager
|
||||||
|
|
||||||
|
|
||||||
def test_playwright_mcp_connection_error_includes_install_hint():
|
def test_playwright_mcp_connection_error_includes_install_hint():
|
||||||
@@ -24,3 +27,15 @@ def test_generic_mcp_connection_error_preserves_original_error():
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert msg == "boom"
|
assert msg == "boom"
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_transport_routes_to_start_http_connect():
|
||||||
|
mgr = McpManager()
|
||||||
|
|
||||||
|
async def fake_start(server_id, name, url):
|
||||||
|
return "ROUTED"
|
||||||
|
|
||||||
|
with patch.object(McpManager, "_start_http_connect", side_effect=fake_start) as m:
|
||||||
|
result = asyncio.run(mgr.connect_server("id1", "n", "http", url="https://x/mcp"))
|
||||||
|
assert result == "ROUTED"
|
||||||
|
m.assert_called_once()
|
||||||
|
|||||||
81
tests/test_mcp_oauth.py
Normal file
81
tests/test_mcp_oauth.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import asyncio
|
||||||
|
from src import mcp_oauth
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_resolve_returns_code_and_state():
|
||||||
|
async def go():
|
||||||
|
fut = mcp_oauth.register_pending("st-1")
|
||||||
|
assert mcp_oauth.resolve_pending("st-1", "the-code") is True
|
||||||
|
return await asyncio.wait_for(fut, timeout=1)
|
||||||
|
code, state = asyncio.run(go())
|
||||||
|
assert code == "the-code"
|
||||||
|
assert state == "st-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_unknown_state_is_false():
|
||||||
|
assert mcp_oauth.resolve_pending("nope", "x") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_pending_prunes_abandoned_flows():
|
||||||
|
import time as _t
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
mcp_oauth._pending.clear()
|
||||||
|
mcp_oauth._pending_ts.clear()
|
||||||
|
old = mcp_oauth.register_pending("old-state")
|
||||||
|
# Backdate the entry past the authorization window.
|
||||||
|
mcp_oauth._pending_ts["old-state"] = _t.monotonic() - (mcp_oauth.AUTH_WAIT_SECONDS + 1)
|
||||||
|
# A new registration triggers a prune of the stale one.
|
||||||
|
mcp_oauth.register_pending("new-state")
|
||||||
|
return old
|
||||||
|
|
||||||
|
old = asyncio.run(go())
|
||||||
|
assert "old-state" not in mcp_oauth._pending
|
||||||
|
assert "old-state" not in mcp_oauth._pending_ts
|
||||||
|
assert "new-state" in mcp_oauth._pending
|
||||||
|
assert old.cancelled()
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_provider_has_odysseus_client_metadata():
|
||||||
|
p = mcp_oauth.build_provider("srv-1", "https://example.com/mcp")
|
||||||
|
md = p.context.client_metadata
|
||||||
|
assert md.client_name == "Odysseus"
|
||||||
|
assert "authorization_code" in md.grant_types
|
||||||
|
assert "refresh_token" in md.grant_types
|
||||||
|
assert str(md.redirect_uris[0]).rstrip("/") == mcp_oauth.REDIRECT_URI.rstrip("/")
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_token_storage_round_trip():
|
||||||
|
from mcp.shared.auth import OAuthToken
|
||||||
|
|
||||||
|
class FakeSrv:
|
||||||
|
oauth_tokens = None
|
||||||
|
|
||||||
|
srv = FakeSrv()
|
||||||
|
|
||||||
|
class FakeQuery:
|
||||||
|
def filter(self, *a):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return srv
|
||||||
|
|
||||||
|
class FakeSession:
|
||||||
|
def query(self, *a):
|
||||||
|
return FakeQuery()
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
storage = mcp_oauth.DbTokenStorage("srv-1", session_factory=lambda: FakeSession())
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
await storage.set_tokens(OAuthToken(access_token="abc", token_type="Bearer"))
|
||||||
|
return await storage.get_tokens()
|
||||||
|
|
||||||
|
t = asyncio.run(go())
|
||||||
|
assert t.access_token == "abc"
|
||||||
|
assert srv.oauth_tokens is not None # persisted as JSON
|
||||||
Reference in New Issue
Block a user