Skip to content

Test auth #609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from pydantic import BaseModel, Field
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from sse_starlette import EventSourceResponse
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route, request_response # type: ignore
from starlette.routing import Mount, Route
from starlette.types import Receive, Scope, Send

from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.bearer_auth import (
Expand Down Expand Up @@ -576,20 +576,19 @@ def sse_app(self) -> Starlette:

sse = SseServerTransport(self.settings.message_path)

async def handle_sse(request: Request) -> EventSourceResponse:
async def handle_sse(scope: Scope, receive: Receive, send: Send):
# Add client ID from auth context into request context if available

async with sse.connect_sse(
request.scope,
request.receive,
request._send, # type: ignore[reportPrivateUsage]
scope,
receive,
send,
) as streams:
await self._mcp_server.run(
streams[0],
streams[1],
self._mcp_server.create_initialization_options(),
)
return streams[2]

# Create routes
routes: list[Route | Mount] = []
Expand Down Expand Up @@ -629,7 +628,7 @@ async def handle_sse(request: Request) -> EventSourceResponse:
Route(
self.settings.sse_path,
endpoint=RequireAuthMiddleware(
request_response(handle_sse), required_scopes
handle_sse, required_scopes
),
methods=["GET"],
)
Expand Down
21 changes: 10 additions & 11 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,15 @@ async def sse_writer():
}
)

# Ensure all streams are properly closed
async with read_stream, write_stream, read_stream_writer, sse_stream_reader:
async with anyio.create_task_group() as tg:
response = EventSourceResponse(
content=sse_stream_reader, data_sender_callable=sse_writer
)
logger.debug("Starting SSE response task")
tg.start_soon(response, scope, receive, send)

logger.debug("Yielding read and write streams")
yield (read_stream, write_stream, response)
async with anyio.create_task_group() as tg:
response = EventSourceResponse(
content=sse_stream_reader, data_sender_callable=sse_writer
)
logger.debug("Starting SSE response task")
tg.start_soon(response, scope, receive, send)

logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)

async def handle_post_message(
self, scope: Scope, receive: Receive, send: Send
Expand Down Expand Up @@ -175,3 +173,4 @@ async def handle_post_message(
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(message)

170 changes: 4 additions & 166 deletions tests/server/fastmcp/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@

import base64
import hashlib
import json
import secrets
import time
import unittest.mock
from urllib.parse import parse_qs, urlparse

import anyio
import httpx
import pytest
from httpx_sse import aconnect_sse
from pydantic import AnyHttpUrl
from starlette.applications import Starlette

Expand All @@ -30,14 +27,10 @@
RevocationOptions,
create_auth_routes,
)
from mcp.server.auth.settings import AuthSettings
from mcp.server.fastmcp import FastMCP
from mcp.server.streaming_asgi_transport import StreamingASGITransport
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthToken,
)
from mcp.types import JSONRPCRequest


# Mock OAuth provider for testing
Expand Down Expand Up @@ -230,10 +223,11 @@ def auth_app(mock_oauth_provider):


@pytest.fixture
def test_client(auth_app) -> httpx.AsyncClient:
return httpx.AsyncClient(
async def test_client(auth_app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com"
)
) as client:
yield client


@pytest.fixture
Expand Down Expand Up @@ -993,163 +987,7 @@ async def test_client_registration_invalid_grant_type(
)


class TestFastMCPWithAuth:
"""Test FastMCP server with authentication."""

@pytest.mark.anyio
async def test_fastmcp_with_auth(
self, mock_oauth_provider: MockOAuthProvider, pkce_challenge
):
"""Test creating a FastMCP server with authentication."""
# Create FastMCP server with auth provider
mcp = FastMCP(
auth_server_provider=mock_oauth_provider,
require_auth=True,
auth=AuthSettings(
issuer_url=AnyHttpUrl("https://auth.example.com"),
client_registration_options=ClientRegistrationOptions(enabled=True),
revocation_options=RevocationOptions(enabled=True),
required_scopes=["read", "write"],
),
)

# Add a test tool
@mcp.tool()
def test_tool(x: int) -> str:
return f"Result: {x}"

async with anyio.create_task_group() as task_group:
transport = StreamingASGITransport(
app=mcp.sse_app(),
task_group=task_group,
)
test_client = httpx.AsyncClient(
transport=transport, base_url="http://mcptest.com"
)

# Test metadata endpoint
response = await test_client.get("/.well-known/oauth-authorization-server")
assert response.status_code == 200

# Test that auth is required for protected endpoints
response = await test_client.get("/sse")
assert response.status_code == 401

response = await test_client.post("/messages/")
assert response.status_code == 401, response.content

response = await test_client.post(
"/messages/",
headers={"Authorization": "invalid"},
)
assert response.status_code == 401

response = await test_client.post(
"/messages/",
headers={"Authorization": "Bearer invalid"},
)
assert response.status_code == 401

# now, become authenticated and try to go through the flow again
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Test Client",
}

response = await test_client.post(
"/register",
json=client_metadata,
)
assert response.status_code == 201
client_info = response.json()

# Request authorization using POST with form-encoded data
response = await test_client.post(
"/authorize",
data={
"response_type": "code",
"client_id": client_info["client_id"],
"redirect_uri": "https://client.example.com/callback",
"code_challenge": pkce_challenge["code_challenge"],
"code_challenge_method": "S256",
"state": "test_state",
},
)
assert response.status_code == 302

# Extract the authorization code from the redirect URL
redirect_url = response.headers["location"]
parsed_url = urlparse(redirect_url)
query_params = parse_qs(parsed_url.query)

assert "code" in query_params
auth_code = query_params["code"][0]

# Exchange the authorization code for tokens
response = await test_client.post(
"/token",
data={
"grant_type": "authorization_code",
"client_id": client_info["client_id"],
"client_secret": client_info["client_secret"],
"code": auth_code,
"code_verifier": pkce_challenge["code_verifier"],
"redirect_uri": "https://client.example.com/callback",
},
)
assert response.status_code == 200

token_response = response.json()
assert "access_token" in token_response
authorization = f"Bearer {token_response['access_token']}"

# Test the authenticated endpoint with valid token
async with aconnect_sse(
test_client, "GET", "/sse", headers={"Authorization": authorization}
) as event_source:
assert event_source.response.status_code == 200
events = event_source.aiter_sse()
sse = await events.__anext__()
assert sse.event == "endpoint"
assert sse.data.startswith("/messages/?session_id=")
messages_uri = sse.data

# verify that we can now post to the /messages endpoint,
# and get a response on the /sse endpoint
response = await test_client.post(
messages_uri,
headers={"Authorization": authorization},
content=JSONRPCRequest(
jsonrpc="2.0",
id="123",
method="initialize",
params={
"protocolVersion": "2024-11-05",
"capabilities": {
"roots": {"listChanged": True},
"sampling": {},
},
"clientInfo": {"name": "ExampleClient", "version": "1.0.0"},
},
).model_dump_json(),
)
assert response.status_code == 202
assert response.content == b"Accepted"

sse = await events.__anext__()
assert sse.event == "message"
sse_data = json.loads(sse.data)
assert sse_data["id"] == "123"
assert set(sse_data["result"]["capabilities"].keys()) == {
"experimental",
"prompts",
"resources",
"tools",
}
# the /sse endpoint will never finish; normally, the client could just
# disconnect, but in tests the easiest way to do this is to cancel the
# task group
task_group.cancel_scope.cancel()


class TestAuthorizeEndpointErrors:
Expand Down
24 changes: 15 additions & 9 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,20 @@ def server(server_port: int) -> Generator[None, None, None]:
print("server process failed to terminate")


@pytest.fixture()
async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create test client"""
async with httpx.AsyncClient(base_url=server_url) as client:
yield client


# Tests
@pytest.mark.anyio
@pytest.mark.skip(
"fails in CI, but works locally. Need to investigate why."
)
async def test_raw_sse_connection(server, server_url) -> None:
async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
"""Test the SSE connection establishment simply with an HTTP client."""
try:
async with httpx.AsyncClient(base_url=server_url) as http_client:
async with anyio.create_task_group():

async def connection_test() -> None:
async with http_client.stream("GET", "/sse") as response:
assert response.status_code == 200
assert (
Expand All @@ -176,8 +181,9 @@ async def test_raw_sse_connection(server, server_url) -> None:
return
line_number += 1

except Exception as e:
pytest.fail(f"{e}")
# Add timeout to prevent test from hanging if it fails
with anyio.fail_after(3):
await connection_test()


@pytest.mark.anyio
Expand Down Expand Up @@ -243,4 +249,4 @@ async def test_sse_client_timeout(
# we should receive an error here
return

pytest.fail("the client should have timed out and returned an error already")
pytest.fail("the client should have timed out and returned an error already")
Loading