Skip to content

Commit 83e488e

Browse files
Tidying up types to avoid typing module, fixing type errors, updating some signatures to validation functions
1 parent 84bfe75 commit 83e488e

File tree

8 files changed

+35
-33
lines changed

8 files changed

+35
-33
lines changed

src/mcp/server/auth/handlers/authorize.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -43,26 +43,29 @@ class AuthorizationRequest(BaseModel):
4343
class Config:
4444
extra = "ignore"
4545

46-
def validate_scope(requested_scope: str | None, client: OAuthClientInformationFull) -> list[str] | None:
46+
def validate_scope(requested_scope: str | None, scope: str | None) -> list[str] | None:
4747
if requested_scope is None:
4848
return None
4949
requested_scopes = requested_scope.split(" ")
50-
allowed_scopes = [] if client.scope is None else client.scope.split(" ")
50+
allowed_scopes = [] if scope is None else scope.split(" ")
5151
for scope in requested_scopes:
5252
if scope not in allowed_scopes:
5353
raise InvalidRequestError(f"Client was not registered with scope {scope}")
5454
return requested_scopes
5555

56-
def validate_redirect_uri(auth_request: AuthorizationRequest, client: OAuthClientInformationFull) -> AnyHttpUrl:
57-
if auth_request.redirect_uri is not None:
56+
def validate_redirect_uri(redirect_uri: AnyHttpUrl | None, redirect_uris: list[AnyHttpUrl]) -> AnyHttpUrl:
57+
if not redirect_uris:
58+
raise InvalidClientError("Client has no registered redirect URIs")
59+
60+
if redirect_uri is not None:
5861
# Validate redirect_uri against client's registered redirect URIs
59-
if auth_request.redirect_uri not in client.redirect_uris:
62+
if redirect_uri not in redirect_uris:
6063
raise InvalidRequestError(
61-
f"Redirect URI '{auth_request.redirect_uri}' not registered for client"
64+
f"Redirect URI '{redirect_uri}' not registered for client"
6265
)
63-
return auth_request.redirect_uri
64-
elif len(client.redirect_uris) == 1:
65-
return client.redirect_uris[0]
66+
return redirect_uri
67+
elif len(redirect_uris) == 1:
68+
return redirect_uris[0]
6669
else:
6770
raise InvalidRequestError("redirect_uri must be specified when client has multiple registered URIs")
6871

@@ -104,8 +107,8 @@ async def authorization_handler(request: Request) -> Response:
104107

105108

106109
# do validation which is dependent on the client configuration
107-
redirect_uri = validate_redirect_uri(auth_request, client)
108-
scopes = validate_scope(auth_request.scope, client)
110+
redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client.redirect_uris)
111+
scopes = validate_scope(auth_request.scope, client.scope)
109112

110113
auth_params = AuthorizationParams(
111114
state=auth_request.state,

src/mcp/server/auth/handlers/revoke.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts
55
"""
66

7-
from typing import Any, Callable, Dict, Optional
7+
from typing import Any, Callable
88

99
from starlette.requests import Request
1010
from starlette.responses import Response

src/mcp/server/auth/handlers/token.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import base64
88
import hashlib
99
import json
10-
from typing import Annotated, Any, Callable, Dict, List, Literal, Optional, Union
10+
from typing import Annotated, Any, Callable, Literal, Union
1111

1212
from starlette.requests import Request
1313
from starlette.responses import JSONResponse
@@ -44,7 +44,7 @@ class RefreshTokenRequest(ClientAuthRequest):
4444
"""
4545
grant_type: Literal["refresh_token"]
4646
refresh_token: str = Field(..., description="The refresh token")
47-
scope: Optional[str] = Field(None, description="Optional scope parameter")
47+
scope: str | None = Field(None, description="Optional scope parameter")
4848

4949

5050
class TokenRequest(RootModel):

src/mcp/server/auth/middleware/bearer_auth.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
import time
8-
from typing import List, Optional, Callable, Awaitable, cast, Dict, Any
8+
from typing import Any, Callable, cast
99

1010
from starlette.requests import HTTPConnection, Request
1111
from starlette.exceptions import HTTPException

src/mcp/server/auth/middleware/client_auth.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
import time
8-
from typing import Optional, Dict, Any, Callable
8+
from typing import Any, Callable
99

1010
from starlette.requests import Request
1111
from starlette.exceptions import HTTPException
@@ -28,7 +28,7 @@ class ClientAuthRequest(BaseModel):
2828
Corresponds to ClientAuthenticatedRequestSchema in src/server/auth/middleware/clientAuth.ts
2929
"""
3030
client_id: str
31-
client_secret: Optional[str] = None
31+
client_secret: str | None = None
3232

3333

3434
class ClientAuthenticator:
@@ -94,7 +94,7 @@ def __init__(
9494
self.app = app
9595
self.client_auth = ClientAuthenticator(clients_store)
9696

97-
async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None:
97+
async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None:
9898
"""
9999
Process the request and authenticate the client.
100100
@@ -112,7 +112,7 @@ async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None
112112

113113
# Add client authentication to the request
114114
try:
115-
client = await self.client_auth(request)
115+
client = await self.client_auth(ClientAuthRequest.model_validate(request))
116116
# Store the client in the request state
117117
request.state.client = client
118118
except HTTPException:

src/mcp/server/auth/provider.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Corresponds to TypeScript file: src/server/auth/provider.ts
55
"""
66

7-
from typing import Any, Dict, List, Optional, Protocol
7+
from typing import Any, Protocol
88
from pydantic import AnyHttpUrl, BaseModel
99
from starlette.responses import Response
1010

@@ -18,8 +18,8 @@ class AuthorizationParams(BaseModel):
1818
1919
Corresponds to AuthorizationParams in src/server/auth/provider.ts
2020
"""
21-
state: Optional[str] = None
22-
scopes: Optional[List[str]] = None
21+
state: str | None = None
22+
scopes: list[str] | None = None
2323
code_challenge: str
2424
redirect_uri: AnyHttpUrl
2525

@@ -31,7 +31,7 @@ class OAuthRegisteredClientsStore(Protocol):
3131
Corresponds to OAuthRegisteredClientsStore in src/server/auth/clients.ts
3232
"""
3333

34-
async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]:
34+
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
3535
"""
3636
Retrieves client information by client ID.
3737
@@ -45,7 +45,7 @@ async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFul
4545

4646
async def register_client(self,
4747
client_info: OAuthClientInformationFull
48-
) -> Optional[OAuthClientInformationFull]:
48+
) -> OAuthClientInformationFull | None:
4949
"""
5050
Registers a new client and returns client information.
5151
@@ -121,7 +121,7 @@ async def exchange_authorization_code(self,
121121
async def exchange_refresh_token(self,
122122
client: OAuthClientInformationFull,
123123
refresh_token: str,
124-
scopes: Optional[List[str]] = None) -> OAuthTokens:
124+
scopes: list[str] | None = None) -> OAuthTokens:
125125
"""
126126
Exchanges a refresh token for an access token.
127127

src/mcp/server/auth/router.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from dataclasses import dataclass
88
import re
9-
from typing import Dict, List, Optional, Any, Union, Callable
9+
from typing import Any, Callable
1010
from urllib.parse import urlparse
1111

1212
from starlette.routing import Route, Router
@@ -26,7 +26,7 @@
2626
@dataclass
2727
class ClientRegistrationOptions:
2828
enabled: bool = False
29-
client_secret_expiry_seconds: Optional[int] = None
29+
client_secret_expiry_seconds: int | None = None
3030

3131
@dataclass
3232
class RevocationOptions:
@@ -145,10 +145,10 @@ def create_auth_router(
145145

146146
def build_metadata(
147147
issuer_url: AnyUrl,
148-
service_documentation_url: Optional[AnyUrl],
148+
service_documentation_url: AnyUrl | None,
149149
client_registration_options: ClientRegistrationOptions,
150150
revocation_options: RevocationOptions,
151-
) -> Dict[str, Any]:
151+
) -> dict[str, Any]:
152152
issuer_url_str = str(issuer_url).rstrip("/")
153153
# Create metadata
154154
metadata = {

src/mcp/server/auth/types.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Corresponds to TypeScript file: src/server/auth/types.ts
55
"""
66

7-
from typing import List, Optional
87
from pydantic import BaseModel
98

109

@@ -16,9 +15,9 @@ class AuthInfo(BaseModel):
1615
"""
1716
token: str
1817
client_id: str
19-
scopes: List[str]
20-
expires_at: Optional[int] = None
21-
user_id: Optional[str] = None
18+
scopes: list[str]
19+
expires_at: int | None = None
20+
user_id: str | None = None
2221

2322
class Config:
2423
extra = "ignore"

0 commit comments

Comments
 (0)