Skip to content

Jerome/auth #260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/mcp/server/auth/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
Corresponds to TypeScript file: src/server/auth/errors.ts
"""

from typing import Dict
from typing import TypedDict

from pydantic import ValidationError


class OAuthErrorResponse(TypedDict):
"""OAuth error response format."""

error: str
error_description: str


class OAuthError(Exception):
"""
Base class for all OAuth errors.
Expand All @@ -22,7 +29,7 @@ def __init__(self, message: str):
super().__init__(message)
self.message = message

def to_response_object(self) -> Dict[str, str]:
def to_response_object(self) -> OAuthErrorResponse:
"""Convert error to JSON response object."""
return {"error": self.error_code, "error_description": self.message}

Expand Down Expand Up @@ -146,5 +153,9 @@ class InsufficientScopeError(OAuthError):

error_code = "insufficient_scope"


def stringify_pydantic_error(validation_error: ValidationError) -> str:
return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors())
return "\n".join(
f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}"
for e in validation_error.errors()
)
94 changes: 57 additions & 37 deletions src/mcp/server/auth/handlers/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,27 @@
Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts
"""

import logging
from typing import Callable, Literal, Optional, Union
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from urllib.parse import urlencode, urlparse, urlunparse

from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.datastructures import FormData, QueryParams
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response

from mcp.server.auth.errors import (
InvalidClientError,
InvalidRequestError,
OAuthError,
stringify_pydantic_error,
)
from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, construct_redirect_uri
from mcp.shared.auth import OAuthClientInformationFull
from mcp.server.auth.json_response import PydanticJSONResponse

import logging
from mcp.server.auth.provider import (
AuthorizationParams,
OAuthServerProvider,
construct_redirect_uri,
)
from mcp.shared.auth import OAuthClientInformationFull

logger = logging.getLogger(__name__)

Expand All @@ -48,7 +50,6 @@ class AuthorizationRequest(BaseModel):
description="Optional scope; if specified, should be "
"a space-separated list of scope strings",
)



def validate_scope(
Expand Down Expand Up @@ -80,30 +81,38 @@ def validate_redirect_uri(
raise InvalidRequestError(
"redirect_uri must be specified when client has multiple registered URIs"
)


ErrorCode = Literal[
"invalid_request",
"unauthorized_client",
"access_denied",
"unsupported_response_type",
"invalid_scope",
"server_error",
"temporarily_unavailable"
]
"invalid_request",
"unauthorized_client",
"access_denied",
"unsupported_response_type",
"invalid_scope",
"server_error",
"temporarily_unavailable",
]


class ErrorResponse(BaseModel):
error: ErrorCode
error_description: str
error_uri: Optional[AnyUrl] = None
# must be set if provided in the request
state: Optional[str]

def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> Optional[str]:

def best_effort_extract_string(
key: str, params: None | FormData | QueryParams
) -> Optional[str]:
if params is None:
return None
value = params.get(key)
if isinstance(value, str):
return value
return None


class AnyHttpUrlModel(RootModel):
root: AnyHttpUrl

Expand All @@ -118,18 +127,24 @@ async def authorization_handler(request: Request) -> Response:
client = None
params = None

async def error_response(error: ErrorCode, error_description: str, attempt_load_client: bool = True):
async def error_response(
error: ErrorCode, error_description: str, attempt_load_client: bool = True
):
nonlocal client, redirect_uri, state
if client is None and attempt_load_client:
# make last-ditch attempt to load the client
client_id = best_effort_extract_string("client_id", params)
client = client_id and await provider.clients_store.get_client(client_id)
client = client_id and await provider.clients_store.get_client(
client_id
)
if redirect_uri is None and client:
# make last-ditch effort to load the redirect uri
if params is not None and "redirect_uri" not in params:
raw_redirect_uri = None
else:
raw_redirect_uri = AnyHttpUrlModel.model_validate(best_effort_extract_string("redirect_uri", params)).root
raw_redirect_uri = AnyHttpUrlModel.model_validate(
best_effort_extract_string("redirect_uri", params)
).root
try:
redirect_uri = validate_redirect_uri(raw_redirect_uri, client)
except (ValidationError, InvalidRequestError):
Expand All @@ -146,7 +161,9 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_

if redirect_uri and client:
return RedirectResponse(
url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
url=construct_redirect_uri(
str(redirect_uri), **error_resp.model_dump(exclude_none=True)
),
status_code=302,
headers={"Cache-Control": "no-store"},
)
Expand All @@ -156,7 +173,7 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_
content=error_resp,
headers={"Cache-Control": "no-store"},
)

try:
# Parse request parameters
if request.method == "GET":
Expand All @@ -165,20 +182,22 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_
else:
# Parse form data for POST requests
params = await request.form()

# Save state if it exists, even before validation
state = best_effort_extract_string("state", params)

try:
auth_request = AuthorizationRequest.model_validate(params)
state = auth_request.state # Update with validated state
except ValidationError as validation_error:
error: ErrorCode = "invalid_request"
for e in validation_error.errors():
if e['loc'] == ('response_type',) and e['type'] == 'literal_error':
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
error = "unsupported_response_type"
break
return await error_response(error, stringify_pydantic_error(validation_error))
return await error_response(
error, stringify_pydantic_error(validation_error)
)

# Get client information
client = await provider.clients_store.get_client(auth_request.client_id)
Expand All @@ -190,7 +209,6 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_
attempt_load_client=False,
)


# Validate redirect_uri against client's registered URIs
try:
redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client)
Expand All @@ -200,7 +218,7 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_
error="invalid_request",
error_description=validation_error.message,
)

# Validate scope - for scope errors, we can redirect
try:
scopes = validate_scope(auth_request.scope, client)
Expand All @@ -210,28 +228,30 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_
error="invalid_scope",
error_description=validation_error.message,
)

# Setup authorization parameters
auth_params = AuthorizationParams(
state=state,
scopes=scopes,
code_challenge=auth_request.code_challenge,
redirect_uri=redirect_uri,
)

# Let the provider pick the next URI to redirect to
response = RedirectResponse(
url="", status_code=302, headers={"Cache-Control": "no-store"}
)
response.headers["location"] = await provider.authorize(
client, auth_params
)
response.headers["location"] = await provider.authorize(client, auth_params)
return response

except Exception as validation_error:
# Catch-all for unexpected errors
logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
return await error_response(error="server_error", error_description="An unexpected error occurred")
logger.exception(
"Unexpected error in authorization_handler", exc_info=validation_error
)
return await error_response(
error="server_error", error_description="An unexpected error occurred"
)

return authorization_handler

Expand All @@ -240,7 +260,7 @@ def create_error_redirect(
redirect_uri: AnyUrl, error: Union[Exception, ErrorResponse]
) -> str:
parsed_uri = urlparse(str(redirect_uri))

if isinstance(error, ErrorResponse):
# Convert ErrorResponse to dict
error_dict = error.model_dump(exclude_none=True)
Expand All @@ -251,7 +271,7 @@ def create_error_redirect(
query_params[key] = str(value)
else:
query_params[key] = value

elif isinstance(error, OAuthError):
query_params = {"error": error.error_code, "error_description": str(error)}
else:
Expand Down
11 changes: 7 additions & 4 deletions src/mcp/server/auth/handlers/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
Corresponds to TypeScript file: src/server/auth/handlers/metadata.ts
"""

from typing import Any, Callable, Dict
from typing import Callable

from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from mcp.shared.auth import OAuthMetadata

def create_metadata_handler(metadata: Dict[str, Any]) -> Callable:

def create_metadata_handler(metadata: OAuthMetadata) -> Callable:
"""
Create a handler for OAuth 2.0 Authorization Server Metadata.

Expand All @@ -33,8 +35,9 @@ async def metadata_handler(request: Request) -> Response:
Returns:
JSON response with the authorization server metadata
"""
# Remove any None values from metadata
clean_metadata = {k: v for k, v in metadata.items() if v is not None}
# Convert metadata to dict and remove any None values
metadata_dict = metadata.model_dump()
clean_metadata = {k: v for k, v in metadata_dict.items() if v is not None}

return JSONResponse(
content=clean_metadata,
Expand Down
Loading
Loading