feat(mcp): add Streamable HTTP transport with OAuth 2.0 (#1033)
* feat(mcp): add Streamable HTTP transport with OAuth 2.0 Odysseus could only reach MCP servers over stdio and SSE, so modern remote servers like https://mcp.higgsfield.ai/mcp (Streamable HTTP, gated behind OAuth) could not be connected. Add an `http` transport that connects via the SDK's streamablehttp_client and authenticates with the SDK's OAuthClientProvider: RFC 9728 protected-resource discovery, RFC 8414 authorization-server metadata, Dynamic Client Registration, authorization-code + PKCE, and token refresh. A small bridge (src/mcp_oauth.py) connects the SDK's blocking callback to the existing web callback route via an asyncio.Future keyed by the OAuth `state`, and the dynamic client registration plus tokens persist per-server in a new encrypted `oauth_tokens` column. The connect runs as a bounded background task so the "Add server" request returns immediately; redirect_handler publishes needs_auth + auth_url to connection state as soon as discovery/DCR completes (which can exceed the bounded wait), and the UI polls until connected. Remote users finish via the existing paste-back flow. The Google OAuth path is left unchanged. - core/database.py: encrypted oauth_tokens column + migration - src/mcp_oauth.py: OAuth provider, DB-backed TokenStorage, state registry - src/mcp_manager.py: http dispatch, background connect, _connect_http - routes/mcp_routes.py: http validation, needs_auth/auth_url, callback bridge - static/js/settings.js: Streamable HTTP option + OAuth flow with polling - tests: 5 new unit tests (transport dispatch, registry, token storage) Verified against the live Higgsfield server: discovery, DCR (client_id issued), loopback redirect accepted, and a PKCE authorization URL with needs_auth status. No regressions (full suite delta is only the 5 added passing tests). * fix(mcp): address PR #1033 review feedback - mcp_oauth: derive redirect URI from OAUTH_REDIRECT_BASE_URL/APP_PUBLIC_URL (default http://localhost:7000) instead of hardcoding the port - mcp_oauth: leave OAuth scope unset so the SDK derives it from the server's WWW-Authenticate/protected-resource metadata; hardcoding an OIDC scope broke non-OpenID MCP servers (verified: Higgsfield still gets its server-derived scope) - mcp_oauth: prune abandoned OAuth flows (_prune_stale + _pending_ts) so the module-level registries can't grow unbounded - mcp_oauth: persist tokens/client-info in a single DB session/commit (_update) instead of a load+save double round-trip - mcp_manager: cancel and drop the background connect task in disconnect_server so a deleted server stops publishing status - database: document why the oauth_tokens migration uses TEXT while the model declares EncryptedText (encryption is applied at the Python layer) - settings.js: surface persistent OAuth-poll failures and an explicit timeout message instead of silently swallowing errors - tests: cover the stale-flow pruning * static/js/settings.js now shows an in-flight loading state on the buttons that fire requests:
This commit is contained in:
committed by
GitHub
parent
85334e8f3d
commit
1d80bf5e65
@@ -1,4 +1,7 @@
|
||||
from src.mcp_manager import _format_mcp_connection_error
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from src.mcp_manager import _format_mcp_connection_error, McpManager
|
||||
|
||||
|
||||
def test_playwright_mcp_connection_error_includes_install_hint():
|
||||
@@ -24,3 +27,15 @@ def test_generic_mcp_connection_error_preserves_original_error():
|
||||
)
|
||||
|
||||
assert msg == "boom"
|
||||
|
||||
|
||||
def test_http_transport_routes_to_start_http_connect():
|
||||
mgr = McpManager()
|
||||
|
||||
async def fake_start(server_id, name, url):
|
||||
return "ROUTED"
|
||||
|
||||
with patch.object(McpManager, "_start_http_connect", side_effect=fake_start) as m:
|
||||
result = asyncio.run(mgr.connect_server("id1", "n", "http", url="https://x/mcp"))
|
||||
assert result == "ROUTED"
|
||||
m.assert_called_once()
|
||||
|
||||
81
tests/test_mcp_oauth.py
Normal file
81
tests/test_mcp_oauth.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import asyncio
|
||||
from src import mcp_oauth
|
||||
|
||||
|
||||
def test_registry_resolve_returns_code_and_state():
|
||||
async def go():
|
||||
fut = mcp_oauth.register_pending("st-1")
|
||||
assert mcp_oauth.resolve_pending("st-1", "the-code") is True
|
||||
return await asyncio.wait_for(fut, timeout=1)
|
||||
code, state = asyncio.run(go())
|
||||
assert code == "the-code"
|
||||
assert state == "st-1"
|
||||
|
||||
|
||||
def test_resolve_unknown_state_is_false():
|
||||
assert mcp_oauth.resolve_pending("nope", "x") is False
|
||||
|
||||
|
||||
def test_register_pending_prunes_abandoned_flows():
|
||||
import time as _t
|
||||
|
||||
async def go():
|
||||
mcp_oauth._pending.clear()
|
||||
mcp_oauth._pending_ts.clear()
|
||||
old = mcp_oauth.register_pending("old-state")
|
||||
# Backdate the entry past the authorization window.
|
||||
mcp_oauth._pending_ts["old-state"] = _t.monotonic() - (mcp_oauth.AUTH_WAIT_SECONDS + 1)
|
||||
# A new registration triggers a prune of the stale one.
|
||||
mcp_oauth.register_pending("new-state")
|
||||
return old
|
||||
|
||||
old = asyncio.run(go())
|
||||
assert "old-state" not in mcp_oauth._pending
|
||||
assert "old-state" not in mcp_oauth._pending_ts
|
||||
assert "new-state" in mcp_oauth._pending
|
||||
assert old.cancelled()
|
||||
|
||||
|
||||
def test_build_provider_has_odysseus_client_metadata():
|
||||
p = mcp_oauth.build_provider("srv-1", "https://example.com/mcp")
|
||||
md = p.context.client_metadata
|
||||
assert md.client_name == "Odysseus"
|
||||
assert "authorization_code" in md.grant_types
|
||||
assert "refresh_token" in md.grant_types
|
||||
assert str(md.redirect_uris[0]).rstrip("/") == mcp_oauth.REDIRECT_URI.rstrip("/")
|
||||
|
||||
|
||||
def test_db_token_storage_round_trip():
|
||||
from mcp.shared.auth import OAuthToken
|
||||
|
||||
class FakeSrv:
|
||||
oauth_tokens = None
|
||||
|
||||
srv = FakeSrv()
|
||||
|
||||
class FakeQuery:
|
||||
def filter(self, *a):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return srv
|
||||
|
||||
class FakeSession:
|
||||
def query(self, *a):
|
||||
return FakeQuery()
|
||||
|
||||
def commit(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
storage = mcp_oauth.DbTokenStorage("srv-1", session_factory=lambda: FakeSession())
|
||||
|
||||
async def go():
|
||||
await storage.set_tokens(OAuthToken(access_token="abc", token_type="Bearer"))
|
||||
return await storage.get_tokens()
|
||||
|
||||
t = asyncio.run(go())
|
||||
assert t.access_token == "abc"
|
||||
assert srv.oauth_tokens is not None # persisted as JSON
|
||||
Reference in New Issue
Block a user