Skip to content

Commit 2210c1b

Browse files
praboud-antdsp-antbhosmer-antihrpr
authored
Add support for serverside oauth (#255)
Co-authored-by: David Soria Parra <[email protected]> Co-authored-by: Basil Hosmer <[email protected]> Co-authored-by: ihrpr <[email protected]>
1 parent 82bd8bc commit 2210c1b

31 files changed

+4120
-22
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,5 @@ cython_debug/
166166

167167
# vscode
168168
.vscode/
169+
.windsurfrules
169170
**/CLAUDE.local.md

CLAUDE.md

+9-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ This document contains critical information about working with this codebase. Fo
1919
- Line length: 88 chars maximum
2020

2121
3. Testing Requirements
22-
- Framework: `uv run pytest`
22+
- Framework: `uv run --frozen pytest`
2323
- Async testing: use anyio, not asyncio
2424
- Coverage: test edge cases and errors
2525
- New features require tests
@@ -54,9 +54,9 @@ This document contains critical information about working with this codebase. Fo
5454
## Code Formatting
5555

5656
1. Ruff
57-
- Format: `uv run ruff format .`
58-
- Check: `uv run ruff check .`
59-
- Fix: `uv run ruff check . --fix`
57+
- Format: `uv run --frozen ruff format .`
58+
- Check: `uv run --frozen ruff check .`
59+
- Fix: `uv run --frozen ruff check . --fix`
6060
- Critical issues:
6161
- Line length (88 chars)
6262
- Import sorting (I001)
@@ -67,7 +67,7 @@ This document contains critical information about working with this codebase. Fo
6767
- Imports: split into multiple lines
6868

6969
2. Type Checking
70-
- Tool: `uv run pyright`
70+
- Tool: `uv run --frozen pyright`
7171
- Requirements:
7272
- Explicit None checks for Optional
7373
- Type narrowing for strings
@@ -104,6 +104,10 @@ This document contains critical information about working with this codebase. Fo
104104
- Add None checks
105105
- Narrow string types
106106
- Match existing patterns
107+
- Pytest:
108+
- If the tests aren't finding the anyio pytest mark, try adding PYTEST_DISABLE_PLUGIN_AUTOLOAD=""
109+
to the start of the pytest run command eg:
110+
`PYTEST_DISABLE_PLUGIN_AUTOLOAD="" uv run --frozen pytest`
107111

108112
3. Best Practices
109113
- Check git status before commits

README.md

+27
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,33 @@ async def long_task(files: list[str], ctx: Context) -> str:
309309
return "Processing complete"
310310
```
311311

312+
### Authentication
313+
314+
Authentication can be used by servers that want to expose tools accessing protected resources.
315+
316+
`mcp.server.auth` implements an OAuth 2.0 server interface, which servers can use by
317+
providing an implementation of the `OAuthServerProvider` protocol.
318+
319+
```
320+
mcp = FastMCP("My App",
321+
auth_provider=MyOAuthServerProvider(),
322+
auth=AuthSettings(
323+
issuer_url="https://myapp.com",
324+
revocation_options=RevocationOptions(
325+
enabled=True,
326+
),
327+
client_registration_options=ClientRegistrationOptions(
328+
enabled=True,
329+
valid_scopes=["myscope", "myotherscope"],
330+
default_scopes=["myscope"],
331+
),
332+
required_scopes=["myscope"],
333+
),
334+
)
335+
```
336+
337+
See [OAuthServerProvider](mcp/server/auth/provider.py) for more details.
338+
312339
## Running Your Server
313340

314341
### Development Mode

examples/clients/simple-chatbot/mcp_simple_chatbot/main.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,7 @@ async def process_llm_response(self, llm_response: str) -> str:
323323
total = result["total"]
324324
percentage = (progress / total) * 100
325325
logging.info(
326-
f"Progress: {progress}/{total} "
327-
f"({percentage:.1f}%)"
326+
f"Progress: {progress}/{total} ({percentage:.1f}%)"
328327
)
329328

330329
return f"Tool execution result: {result}"

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
"httpx-sse>=0.4",
2828
"pydantic>=2.7.2,<3.0.0",
2929
"starlette>=0.27",
30+
"python-multipart>=0.0.9",
3031
"sse-starlette>=1.6.1",
3132
"pydantic-settings>=2.5.2",
3233
"uvicorn>=0.23.1; sys_platform != 'emscripten'",

src/mcp/server/auth/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""
2+
MCP OAuth server authorization components.
3+
"""

src/mcp/server/auth/errors.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pydantic import ValidationError
2+
3+
4+
def stringify_pydantic_error(validation_error: ValidationError) -> str:
5+
return "\n".join(
6+
f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}"
7+
for e in validation_error.errors()
8+
)
+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""
2+
Request handlers for MCP authorization endpoints.
3+
"""
+244
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
import logging
2+
from dataclasses import dataclass
3+
from typing import Any, Literal
4+
5+
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
6+
from starlette.datastructures import FormData, QueryParams
7+
from starlette.requests import Request
8+
from starlette.responses import RedirectResponse, Response
9+
10+
from mcp.server.auth.errors import (
11+
stringify_pydantic_error,
12+
)
13+
from mcp.server.auth.json_response import PydanticJSONResponse
14+
from mcp.server.auth.provider import (
15+
AuthorizationErrorCode,
16+
AuthorizationParams,
17+
AuthorizeError,
18+
OAuthAuthorizationServerProvider,
19+
construct_redirect_uri,
20+
)
21+
from mcp.shared.auth import (
22+
InvalidRedirectUriError,
23+
InvalidScopeError,
24+
)
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
class AuthorizationRequest(BaseModel):
30+
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
31+
client_id: str = Field(..., description="The client ID")
32+
redirect_uri: AnyHttpUrl | None = Field(
33+
None, description="URL to redirect to after authorization"
34+
)
35+
36+
# see OAuthClientMetadata; we only support `code`
37+
response_type: Literal["code"] = Field(
38+
..., description="Must be 'code' for authorization code flow"
39+
)
40+
code_challenge: str = Field(..., description="PKCE code challenge")
41+
code_challenge_method: Literal["S256"] = Field(
42+
"S256", description="PKCE code challenge method, must be S256"
43+
)
44+
state: str | None = Field(None, description="Optional state parameter")
45+
scope: str | None = Field(
46+
None,
47+
description="Optional scope; if specified, should be "
48+
"a space-separated list of scope strings",
49+
)
50+
51+
52+
class AuthorizationErrorResponse(BaseModel):
53+
error: AuthorizationErrorCode
54+
error_description: str | None
55+
error_uri: AnyUrl | None = None
56+
# must be set if provided in the request
57+
state: str | None = None
58+
59+
60+
def best_effort_extract_string(
61+
key: str, params: None | FormData | QueryParams
62+
) -> str | None:
63+
if params is None:
64+
return None
65+
value = params.get(key)
66+
if isinstance(value, str):
67+
return value
68+
return None
69+
70+
71+
class AnyHttpUrlModel(RootModel[AnyHttpUrl]):
72+
root: AnyHttpUrl
73+
74+
75+
@dataclass
76+
class AuthorizationHandler:
77+
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
78+
79+
async def handle(self, request: Request) -> Response:
80+
# implements authorization requests for grant_type=code;
81+
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
82+
83+
state = None
84+
redirect_uri = None
85+
client = None
86+
params = None
87+
88+
async def error_response(
89+
error: AuthorizationErrorCode,
90+
error_description: str | None,
91+
attempt_load_client: bool = True,
92+
):
93+
# Error responses take two different formats:
94+
# 1. The request has a valid client ID & redirect_uri: we issue a redirect
95+
# back to the redirect_uri with the error response fields as query
96+
# parameters. This allows the client to be notified of the error.
97+
# 2. Otherwise, we return an error response directly to the end user;
98+
# we choose to do so in JSON, but this is left undefined in the
99+
# specification.
100+
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
101+
#
102+
# This logic is a bit awkward to handle, because the error might be thrown
103+
# very early in request validation, before we've done the usual Pydantic
104+
# validation, loaded the client, etc. To handle this, error_response()
105+
# contains fallback logic which attempts to load the parameters directly
106+
# from the request.
107+
108+
nonlocal client, redirect_uri, state
109+
if client is None and attempt_load_client:
110+
# make last-ditch attempt to load the client
111+
client_id = best_effort_extract_string("client_id", params)
112+
client = client_id and await self.provider.get_client(client_id)
113+
if redirect_uri is None and client:
114+
# make last-ditch effort to load the redirect uri
115+
try:
116+
if params is not None and "redirect_uri" not in params:
117+
raw_redirect_uri = None
118+
else:
119+
raw_redirect_uri = AnyHttpUrlModel.model_validate(
120+
best_effort_extract_string("redirect_uri", params)
121+
).root
122+
redirect_uri = client.validate_redirect_uri(raw_redirect_uri)
123+
except (ValidationError, InvalidRedirectUriError):
124+
# if the redirect URI is invalid, ignore it & just return the
125+
# initial error
126+
pass
127+
128+
# the error response MUST contain the state specified by the client, if any
129+
if state is None:
130+
# make last-ditch effort to load state
131+
state = best_effort_extract_string("state", params)
132+
133+
error_resp = AuthorizationErrorResponse(
134+
error=error,
135+
error_description=error_description,
136+
state=state,
137+
)
138+
139+
if redirect_uri and client:
140+
return RedirectResponse(
141+
url=construct_redirect_uri(
142+
str(redirect_uri), **error_resp.model_dump(exclude_none=True)
143+
),
144+
status_code=302,
145+
headers={"Cache-Control": "no-store"},
146+
)
147+
else:
148+
return PydanticJSONResponse(
149+
status_code=400,
150+
content=error_resp,
151+
headers={"Cache-Control": "no-store"},
152+
)
153+
154+
try:
155+
# Parse request parameters
156+
if request.method == "GET":
157+
# Convert query_params to dict for pydantic validation
158+
params = request.query_params
159+
else:
160+
# Parse form data for POST requests
161+
params = await request.form()
162+
163+
# Save state if it exists, even before validation
164+
state = best_effort_extract_string("state", params)
165+
166+
try:
167+
auth_request = AuthorizationRequest.model_validate(params)
168+
state = auth_request.state # Update with validated state
169+
except ValidationError as validation_error:
170+
error: AuthorizationErrorCode = "invalid_request"
171+
for e in validation_error.errors():
172+
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
173+
error = "unsupported_response_type"
174+
break
175+
return await error_response(
176+
error, stringify_pydantic_error(validation_error)
177+
)
178+
179+
# Get client information
180+
client = await self.provider.get_client(
181+
auth_request.client_id,
182+
)
183+
if not client:
184+
# For client_id validation errors, return direct error (no redirect)
185+
return await error_response(
186+
error="invalid_request",
187+
error_description=f"Client ID '{auth_request.client_id}' not found",
188+
attempt_load_client=False,
189+
)
190+
191+
# Validate redirect_uri against client's registered URIs
192+
try:
193+
redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri)
194+
except InvalidRedirectUriError as validation_error:
195+
# For redirect_uri validation errors, return direct error (no redirect)
196+
return await error_response(
197+
error="invalid_request",
198+
error_description=validation_error.message,
199+
)
200+
201+
# Validate scope - for scope errors, we can redirect
202+
try:
203+
scopes = client.validate_scope(auth_request.scope)
204+
except InvalidScopeError as validation_error:
205+
# For scope errors, redirect with error parameters
206+
return await error_response(
207+
error="invalid_scope",
208+
error_description=validation_error.message,
209+
)
210+
211+
# Setup authorization parameters
212+
auth_params = AuthorizationParams(
213+
state=state,
214+
scopes=scopes,
215+
code_challenge=auth_request.code_challenge,
216+
redirect_uri=redirect_uri,
217+
redirect_uri_provided_explicitly=auth_request.redirect_uri is not None,
218+
)
219+
220+
try:
221+
# Let the provider pick the next URI to redirect to
222+
return RedirectResponse(
223+
url=await self.provider.authorize(
224+
client,
225+
auth_params,
226+
),
227+
status_code=302,
228+
headers={"Cache-Control": "no-store"},
229+
)
230+
except AuthorizeError as e:
231+
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
232+
return await error_response(
233+
error=e.error,
234+
error_description=e.error_description,
235+
)
236+
237+
except Exception as validation_error:
238+
# Catch-all for unexpected errors
239+
logger.exception(
240+
"Unexpected error in authorization_handler", exc_info=validation_error
241+
)
242+
return await error_response(
243+
error="server_error", error_description="An unexpected error occurred"
244+
)
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from dataclasses import dataclass
2+
3+
from starlette.requests import Request
4+
from starlette.responses import Response
5+
6+
from mcp.server.auth.json_response import PydanticJSONResponse
7+
from mcp.shared.auth import OAuthMetadata
8+
9+
10+
@dataclass
11+
class MetadataHandler:
12+
metadata: OAuthMetadata
13+
14+
async def handle(self, request: Request) -> Response:
15+
return PydanticJSONResponse(
16+
content=self.metadata,
17+
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
18+
)

0 commit comments

Comments
 (0)