|
4 | 4 |
|
5 | 5 | import base64
|
6 | 6 | import hashlib
|
7 |
| -import json |
8 | 7 | import secrets
|
9 | 8 | import time
|
10 | 9 | import unittest.mock
|
11 | 10 | from urllib.parse import parse_qs, urlparse
|
12 | 11 |
|
13 |
| -import anyio |
14 | 12 | import httpx
|
15 | 13 | import pytest
|
16 |
| -from httpx_sse import aconnect_sse |
17 | 14 | from pydantic import AnyHttpUrl
|
18 | 15 | from starlette.applications import Starlette
|
19 | 16 |
|
|
30 | 27 | RevocationOptions,
|
31 | 28 | create_auth_routes,
|
32 | 29 | )
|
33 |
| -from mcp.server.auth.settings import AuthSettings |
34 |
| -from mcp.server.fastmcp import FastMCP |
35 |
| -from mcp.server.streaming_asgi_transport import StreamingASGITransport |
36 | 30 | from mcp.shared.auth import (
|
37 | 31 | OAuthClientInformationFull,
|
38 | 32 | OAuthToken,
|
39 | 33 | )
|
40 |
| -from mcp.types import JSONRPCRequest |
41 | 34 |
|
42 | 35 |
|
43 | 36 | # Mock OAuth provider for testing
|
@@ -230,10 +223,11 @@ def auth_app(mock_oauth_provider):
|
230 | 223 |
|
231 | 224 |
|
232 | 225 | @pytest.fixture
|
233 |
| -def test_client(auth_app) -> httpx.AsyncClient: |
234 |
| - return httpx.AsyncClient( |
| 226 | +async def test_client(auth_app): |
| 227 | + async with httpx.AsyncClient( |
235 | 228 | transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com"
|
236 |
| - ) |
| 229 | + ) as client: |
| 230 | + yield client |
237 | 231 |
|
238 | 232 |
|
239 | 233 | @pytest.fixture
|
@@ -993,163 +987,7 @@ async def test_client_registration_invalid_grant_type(
|
993 | 987 | )
|
994 | 988 |
|
995 | 989 |
|
996 |
| -class TestFastMCPWithAuth: |
997 |
| - """Test FastMCP server with authentication.""" |
998 |
| - |
999 |
| - @pytest.mark.anyio |
1000 |
| - async def test_fastmcp_with_auth( |
1001 |
| - self, mock_oauth_provider: MockOAuthProvider, pkce_challenge |
1002 |
| - ): |
1003 |
| - """Test creating a FastMCP server with authentication.""" |
1004 |
| - # Create FastMCP server with auth provider |
1005 |
| - mcp = FastMCP( |
1006 |
| - auth_server_provider=mock_oauth_provider, |
1007 |
| - require_auth=True, |
1008 |
| - auth=AuthSettings( |
1009 |
| - issuer_url=AnyHttpUrl("https://auth.example.com"), |
1010 |
| - client_registration_options=ClientRegistrationOptions(enabled=True), |
1011 |
| - revocation_options=RevocationOptions(enabled=True), |
1012 |
| - required_scopes=["read", "write"], |
1013 |
| - ), |
1014 |
| - ) |
1015 |
| - |
1016 |
| - # Add a test tool |
1017 |
| - @mcp.tool() |
1018 |
| - def test_tool(x: int) -> str: |
1019 |
| - return f"Result: {x}" |
1020 |
| - |
1021 |
| - async with anyio.create_task_group() as task_group: |
1022 |
| - transport = StreamingASGITransport( |
1023 |
| - app=mcp.sse_app(), |
1024 |
| - task_group=task_group, |
1025 |
| - ) |
1026 |
| - test_client = httpx.AsyncClient( |
1027 |
| - transport=transport, base_url="http://mcptest.com" |
1028 |
| - ) |
1029 |
| - |
1030 |
| - # Test metadata endpoint |
1031 |
| - response = await test_client.get("/.well-known/oauth-authorization-server") |
1032 |
| - assert response.status_code == 200 |
1033 | 990 |
|
1034 |
| - # Test that auth is required for protected endpoints |
1035 |
| - response = await test_client.get("/sse") |
1036 |
| - assert response.status_code == 401 |
1037 |
| - |
1038 |
| - response = await test_client.post("/messages/") |
1039 |
| - assert response.status_code == 401, response.content |
1040 |
| - |
1041 |
| - response = await test_client.post( |
1042 |
| - "/messages/", |
1043 |
| - headers={"Authorization": "invalid"}, |
1044 |
| - ) |
1045 |
| - assert response.status_code == 401 |
1046 |
| - |
1047 |
| - response = await test_client.post( |
1048 |
| - "/messages/", |
1049 |
| - headers={"Authorization": "Bearer invalid"}, |
1050 |
| - ) |
1051 |
| - assert response.status_code == 401 |
1052 |
| - |
1053 |
| - # now, become authenticated and try to go through the flow again |
1054 |
| - client_metadata = { |
1055 |
| - "redirect_uris": ["https://client.example.com/callback"], |
1056 |
| - "client_name": "Test Client", |
1057 |
| - } |
1058 |
| - |
1059 |
| - response = await test_client.post( |
1060 |
| - "/register", |
1061 |
| - json=client_metadata, |
1062 |
| - ) |
1063 |
| - assert response.status_code == 201 |
1064 |
| - client_info = response.json() |
1065 |
| - |
1066 |
| - # Request authorization using POST with form-encoded data |
1067 |
| - response = await test_client.post( |
1068 |
| - "/authorize", |
1069 |
| - data={ |
1070 |
| - "response_type": "code", |
1071 |
| - "client_id": client_info["client_id"], |
1072 |
| - "redirect_uri": "https://client.example.com/callback", |
1073 |
| - "code_challenge": pkce_challenge["code_challenge"], |
1074 |
| - "code_challenge_method": "S256", |
1075 |
| - "state": "test_state", |
1076 |
| - }, |
1077 |
| - ) |
1078 |
| - assert response.status_code == 302 |
1079 |
| - |
1080 |
| - # Extract the authorization code from the redirect URL |
1081 |
| - redirect_url = response.headers["location"] |
1082 |
| - parsed_url = urlparse(redirect_url) |
1083 |
| - query_params = parse_qs(parsed_url.query) |
1084 |
| - |
1085 |
| - assert "code" in query_params |
1086 |
| - auth_code = query_params["code"][0] |
1087 |
| - |
1088 |
| - # Exchange the authorization code for tokens |
1089 |
| - response = await test_client.post( |
1090 |
| - "/token", |
1091 |
| - data={ |
1092 |
| - "grant_type": "authorization_code", |
1093 |
| - "client_id": client_info["client_id"], |
1094 |
| - "client_secret": client_info["client_secret"], |
1095 |
| - "code": auth_code, |
1096 |
| - "code_verifier": pkce_challenge["code_verifier"], |
1097 |
| - "redirect_uri": "https://client.example.com/callback", |
1098 |
| - }, |
1099 |
| - ) |
1100 |
| - assert response.status_code == 200 |
1101 |
| - |
1102 |
| - token_response = response.json() |
1103 |
| - assert "access_token" in token_response |
1104 |
| - authorization = f"Bearer {token_response['access_token']}" |
1105 |
| - |
1106 |
| - # Test the authenticated endpoint with valid token |
1107 |
| - async with aconnect_sse( |
1108 |
| - test_client, "GET", "/sse", headers={"Authorization": authorization} |
1109 |
| - ) as event_source: |
1110 |
| - assert event_source.response.status_code == 200 |
1111 |
| - events = event_source.aiter_sse() |
1112 |
| - sse = await events.__anext__() |
1113 |
| - assert sse.event == "endpoint" |
1114 |
| - assert sse.data.startswith("/messages/?session_id=") |
1115 |
| - messages_uri = sse.data |
1116 |
| - |
1117 |
| - # verify that we can now post to the /messages endpoint, |
1118 |
| - # and get a response on the /sse endpoint |
1119 |
| - response = await test_client.post( |
1120 |
| - messages_uri, |
1121 |
| - headers={"Authorization": authorization}, |
1122 |
| - content=JSONRPCRequest( |
1123 |
| - jsonrpc="2.0", |
1124 |
| - id="123", |
1125 |
| - method="initialize", |
1126 |
| - params={ |
1127 |
| - "protocolVersion": "2024-11-05", |
1128 |
| - "capabilities": { |
1129 |
| - "roots": {"listChanged": True}, |
1130 |
| - "sampling": {}, |
1131 |
| - }, |
1132 |
| - "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, |
1133 |
| - }, |
1134 |
| - ).model_dump_json(), |
1135 |
| - ) |
1136 |
| - assert response.status_code == 202 |
1137 |
| - assert response.content == b"Accepted" |
1138 |
| - |
1139 |
| - sse = await events.__anext__() |
1140 |
| - assert sse.event == "message" |
1141 |
| - sse_data = json.loads(sse.data) |
1142 |
| - assert sse_data["id"] == "123" |
1143 |
| - assert set(sse_data["result"]["capabilities"].keys()) == { |
1144 |
| - "experimental", |
1145 |
| - "prompts", |
1146 |
| - "resources", |
1147 |
| - "tools", |
1148 |
| - } |
1149 |
| - # the /sse endpoint will never finish; normally, the client could just |
1150 |
| - # disconnect, but in tests the easiest way to do this is to cancel the |
1151 |
| - # task group |
1152 |
| - task_group.cancel_scope.cancel() |
1153 | 991 |
|
1154 | 992 |
|
1155 | 993 | class TestAuthorizeEndpointErrors:
|
|
0 commit comments