Skip to content

Commit f2840fe

Browse files
authored
Test auth (#609)
1 parent af4221f commit f2840fe

File tree

4 files changed

+36
-194
lines changed

4 files changed

+36
-194
lines changed

src/mcp/server/fastmcp/server.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from pydantic import BaseModel, Field
1818
from pydantic.networks import AnyUrl
1919
from pydantic_settings import BaseSettings, SettingsConfigDict
20-
from sse_starlette import EventSourceResponse
2120
from starlette.applications import Starlette
2221
from starlette.middleware import Middleware
2322
from starlette.middleware.authentication import AuthenticationMiddleware
2423
from starlette.requests import Request
2524
from starlette.responses import Response
26-
from starlette.routing import Mount, Route, request_response # type: ignore
25+
from starlette.routing import Mount, Route
26+
from starlette.types import Receive, Scope, Send
2727

2828
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
2929
from mcp.server.auth.middleware.bearer_auth import (
@@ -576,20 +576,19 @@ def sse_app(self) -> Starlette:
576576

577577
sse = SseServerTransport(self.settings.message_path)
578578

579-
async def handle_sse(request: Request) -> EventSourceResponse:
579+
async def handle_sse(scope: Scope, receive: Receive, send: Send):
580580
# Add client ID from auth context into request context if available
581581

582582
async with sse.connect_sse(
583-
request.scope,
584-
request.receive,
585-
request._send, # type: ignore[reportPrivateUsage]
583+
scope,
584+
receive,
585+
send,
586586
) as streams:
587587
await self._mcp_server.run(
588588
streams[0],
589589
streams[1],
590590
self._mcp_server.create_initialization_options(),
591591
)
592-
return streams[2]
593592

594593
# Create routes
595594
routes: list[Route | Mount] = []
@@ -629,7 +628,7 @@ async def handle_sse(request: Request) -> EventSourceResponse:
629628
Route(
630629
self.settings.sse_path,
631630
endpoint=RequireAuthMiddleware(
632-
request_response(handle_sse), required_scopes
631+
handle_sse, required_scopes
633632
),
634633
methods=["GET"],
635634
)

src/mcp/server/sse.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,15 @@ async def sse_writer():
120120
}
121121
)
122122

123-
# Ensure all streams are properly closed
124-
async with read_stream, write_stream, read_stream_writer, sse_stream_reader:
125-
async with anyio.create_task_group() as tg:
126-
response = EventSourceResponse(
127-
content=sse_stream_reader, data_sender_callable=sse_writer
128-
)
129-
logger.debug("Starting SSE response task")
130-
tg.start_soon(response, scope, receive, send)
131-
132-
logger.debug("Yielding read and write streams")
133-
yield (read_stream, write_stream, response)
123+
async with anyio.create_task_group() as tg:
124+
response = EventSourceResponse(
125+
content=sse_stream_reader, data_sender_callable=sse_writer
126+
)
127+
logger.debug("Starting SSE response task")
128+
tg.start_soon(response, scope, receive, send)
129+
130+
logger.debug("Yielding read and write streams")
131+
yield (read_stream, write_stream)
134132

135133
async def handle_post_message(
136134
self, scope: Scope, receive: Receive, send: Send
@@ -175,3 +173,4 @@ async def handle_post_message(
175173
response = Response("Accepted", status_code=202)
176174
await response(scope, receive, send)
177175
await writer.send(message)
176+

tests/server/fastmcp/auth/test_auth_integration.py

+4-166
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44

55
import base64
66
import hashlib
7-
import json
87
import secrets
98
import time
109
import unittest.mock
1110
from urllib.parse import parse_qs, urlparse
1211

13-
import anyio
1412
import httpx
1513
import pytest
16-
from httpx_sse import aconnect_sse
1714
from pydantic import AnyHttpUrl
1815
from starlette.applications import Starlette
1916

@@ -30,14 +27,10 @@
3027
RevocationOptions,
3128
create_auth_routes,
3229
)
33-
from mcp.server.auth.settings import AuthSettings
34-
from mcp.server.fastmcp import FastMCP
35-
from mcp.server.streaming_asgi_transport import StreamingASGITransport
3630
from mcp.shared.auth import (
3731
OAuthClientInformationFull,
3832
OAuthToken,
3933
)
40-
from mcp.types import JSONRPCRequest
4134

4235

4336
# Mock OAuth provider for testing
@@ -230,10 +223,11 @@ def auth_app(mock_oauth_provider):
230223

231224

232225
@pytest.fixture
233-
def test_client(auth_app) -> httpx.AsyncClient:
234-
return httpx.AsyncClient(
226+
async def test_client(auth_app):
227+
async with httpx.AsyncClient(
235228
transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com"
236-
)
229+
) as client:
230+
yield client
237231

238232

239233
@pytest.fixture
@@ -993,163 +987,7 @@ async def test_client_registration_invalid_grant_type(
993987
)
994988

995989

996-
class TestFastMCPWithAuth:
997-
"""Test FastMCP server with authentication."""
998-
999-
@pytest.mark.anyio
1000-
async def test_fastmcp_with_auth(
1001-
self, mock_oauth_provider: MockOAuthProvider, pkce_challenge
1002-
):
1003-
"""Test creating a FastMCP server with authentication."""
1004-
# Create FastMCP server with auth provider
1005-
mcp = FastMCP(
1006-
auth_server_provider=mock_oauth_provider,
1007-
require_auth=True,
1008-
auth=AuthSettings(
1009-
issuer_url=AnyHttpUrl("https://auth.example.com"),
1010-
client_registration_options=ClientRegistrationOptions(enabled=True),
1011-
revocation_options=RevocationOptions(enabled=True),
1012-
required_scopes=["read", "write"],
1013-
),
1014-
)
1015-
1016-
# Add a test tool
1017-
@mcp.tool()
1018-
def test_tool(x: int) -> str:
1019-
return f"Result: {x}"
1020-
1021-
async with anyio.create_task_group() as task_group:
1022-
transport = StreamingASGITransport(
1023-
app=mcp.sse_app(),
1024-
task_group=task_group,
1025-
)
1026-
test_client = httpx.AsyncClient(
1027-
transport=transport, base_url="http://mcptest.com"
1028-
)
1029-
1030-
# Test metadata endpoint
1031-
response = await test_client.get("/.well-known/oauth-authorization-server")
1032-
assert response.status_code == 200
1033990

1034-
# Test that auth is required for protected endpoints
1035-
response = await test_client.get("/sse")
1036-
assert response.status_code == 401
1037-
1038-
response = await test_client.post("/messages/")
1039-
assert response.status_code == 401, response.content
1040-
1041-
response = await test_client.post(
1042-
"/messages/",
1043-
headers={"Authorization": "invalid"},
1044-
)
1045-
assert response.status_code == 401
1046-
1047-
response = await test_client.post(
1048-
"/messages/",
1049-
headers={"Authorization": "Bearer invalid"},
1050-
)
1051-
assert response.status_code == 401
1052-
1053-
# now, become authenticated and try to go through the flow again
1054-
client_metadata = {
1055-
"redirect_uris": ["https://client.example.com/callback"],
1056-
"client_name": "Test Client",
1057-
}
1058-
1059-
response = await test_client.post(
1060-
"/register",
1061-
json=client_metadata,
1062-
)
1063-
assert response.status_code == 201
1064-
client_info = response.json()
1065-
1066-
# Request authorization using POST with form-encoded data
1067-
response = await test_client.post(
1068-
"/authorize",
1069-
data={
1070-
"response_type": "code",
1071-
"client_id": client_info["client_id"],
1072-
"redirect_uri": "https://client.example.com/callback",
1073-
"code_challenge": pkce_challenge["code_challenge"],
1074-
"code_challenge_method": "S256",
1075-
"state": "test_state",
1076-
},
1077-
)
1078-
assert response.status_code == 302
1079-
1080-
# Extract the authorization code from the redirect URL
1081-
redirect_url = response.headers["location"]
1082-
parsed_url = urlparse(redirect_url)
1083-
query_params = parse_qs(parsed_url.query)
1084-
1085-
assert "code" in query_params
1086-
auth_code = query_params["code"][0]
1087-
1088-
# Exchange the authorization code for tokens
1089-
response = await test_client.post(
1090-
"/token",
1091-
data={
1092-
"grant_type": "authorization_code",
1093-
"client_id": client_info["client_id"],
1094-
"client_secret": client_info["client_secret"],
1095-
"code": auth_code,
1096-
"code_verifier": pkce_challenge["code_verifier"],
1097-
"redirect_uri": "https://client.example.com/callback",
1098-
},
1099-
)
1100-
assert response.status_code == 200
1101-
1102-
token_response = response.json()
1103-
assert "access_token" in token_response
1104-
authorization = f"Bearer {token_response['access_token']}"
1105-
1106-
# Test the authenticated endpoint with valid token
1107-
async with aconnect_sse(
1108-
test_client, "GET", "/sse", headers={"Authorization": authorization}
1109-
) as event_source:
1110-
assert event_source.response.status_code == 200
1111-
events = event_source.aiter_sse()
1112-
sse = await events.__anext__()
1113-
assert sse.event == "endpoint"
1114-
assert sse.data.startswith("/messages/?session_id=")
1115-
messages_uri = sse.data
1116-
1117-
# verify that we can now post to the /messages endpoint,
1118-
# and get a response on the /sse endpoint
1119-
response = await test_client.post(
1120-
messages_uri,
1121-
headers={"Authorization": authorization},
1122-
content=JSONRPCRequest(
1123-
jsonrpc="2.0",
1124-
id="123",
1125-
method="initialize",
1126-
params={
1127-
"protocolVersion": "2024-11-05",
1128-
"capabilities": {
1129-
"roots": {"listChanged": True},
1130-
"sampling": {},
1131-
},
1132-
"clientInfo": {"name": "ExampleClient", "version": "1.0.0"},
1133-
},
1134-
).model_dump_json(),
1135-
)
1136-
assert response.status_code == 202
1137-
assert response.content == b"Accepted"
1138-
1139-
sse = await events.__anext__()
1140-
assert sse.event == "message"
1141-
sse_data = json.loads(sse.data)
1142-
assert sse_data["id"] == "123"
1143-
assert set(sse_data["result"]["capabilities"].keys()) == {
1144-
"experimental",
1145-
"prompts",
1146-
"resources",
1147-
"tools",
1148-
}
1149-
# the /sse endpoint will never finish; normally, the client could just
1150-
# disconnect, but in tests the easiest way to do this is to cancel the
1151-
# task group
1152-
task_group.cancel_scope.cancel()
1153991

1154992

1155993
class TestAuthorizeEndpointErrors:

tests/shared/test_sse.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,20 @@ def server(server_port: int) -> Generator[None, None, None]:
150150
print("server process failed to terminate")
151151

152152

153+
@pytest.fixture()
154+
async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]:
155+
"""Create test client"""
156+
async with httpx.AsyncClient(base_url=server_url) as client:
157+
yield client
158+
153159

160+
# Tests
154161
@pytest.mark.anyio
155-
@pytest.mark.skip(
156-
"fails in CI, but works locally. Need to investigate why."
157-
)
158-
async def test_raw_sse_connection(server, server_url) -> None:
162+
async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
159163
"""Test the SSE connection establishment simply with an HTTP client."""
160-
try:
161-
async with httpx.AsyncClient(base_url=server_url) as http_client:
164+
async with anyio.create_task_group():
165+
166+
async def connection_test() -> None:
162167
async with http_client.stream("GET", "/sse") as response:
163168
assert response.status_code == 200
164169
assert (
@@ -176,8 +181,9 @@ async def test_raw_sse_connection(server, server_url) -> None:
176181
return
177182
line_number += 1
178183

179-
except Exception as e:
180-
pytest.fail(f"{e}")
184+
# Add timeout to prevent test from hanging if it fails
185+
with anyio.fail_after(3):
186+
await connection_test()
181187

182188

183189
@pytest.mark.anyio
@@ -243,4 +249,4 @@ async def test_sse_client_timeout(
243249
# we should receive an error here
244250
return
245251

246-
pytest.fail("the client should have timed out and returned an error already")
252+
pytest.fail("the client should have timed out and returned an error already")

0 commit comments

Comments
 (0)