Skip to content

Commit 304cc6b

Browse files
committed
update lock
1 parent 97b7789 commit 304cc6b

File tree

3 files changed

+240
-61
lines changed

3 files changed

+240
-61
lines changed

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ dependencies = [
3030
"sse-starlette>=1.6.1",
3131
"pydantic-settings>=2.5.2",
3232
"uvicorn>=0.23.1",
33-
"python-multipart",
3433
]
3534

3635
[project.optional-dependencies]
@@ -48,7 +47,6 @@ dev-dependencies = [
4847
"pytest>=8.3.4",
4948
"ruff>=0.8.5",
5049
"trio>=0.26.2",
51-
"pytest-flakefinder>=1.1.0",
5250
"pytest-xdist>=3.6.1",
5351
]
5452

src/mcp/client/auth/oauth.py

Lines changed: 240 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
authorization specification.
77
"""
88

9+
import base64
10+
import hashlib
911
import json
1012
import logging
1113
from datetime import datetime, timedelta
1214
from typing import Any, Protocol
13-
from urllib.parse import urlparse
15+
from urllib.parse import urlencode, urlparse
1416

1517
import httpx
1618
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
@@ -373,7 +375,49 @@ class OAuthClientProvider(Protocol):
373375
@property
374376
def client_metadata(self) -> ClientMetadata: ...
375377

376-
def save_client_information(self, metadata: DynamicClientRegistration) -> None: ...
378+
@property
379+
def redirect_url(self) -> AnyHttpUrl: ...
380+
381+
async def open_user_agent(self, url: AnyHttpUrl) -> None:
382+
"""
383+
Opens the user agent to the given URL.
384+
"""
385+
...
386+
387+
async def client_registration(
388+
self, endpoint: AnyHttpUrl
389+
) -> DynamicClientRegistration | None:
390+
"""
391+
Loads the client registration for the given endpoint.
392+
"""
393+
...
394+
395+
async def store_client_registration(
396+
self, endpoint: AnyHttpUrl, metadata: DynamicClientRegistration
397+
) -> None:
398+
"""
399+
Stores the client registration to be retreived for the next session
400+
"""
401+
...
402+
403+
def code_verifier(self) -> str:
404+
"""
405+
Loads the PKCE code verifier for the current session.
406+
See https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1
407+
"""
408+
...
409+
410+
async def token(self) -> AccessToken | None:
411+
"""
412+
Loads the token for the current session.
413+
"""
414+
...
415+
416+
async def store_token(self, token: AccessToken) -> None:
417+
"""
418+
Stores the token to be retreived for the next session
419+
"""
420+
...
377421

378422

379423
class NotFoundError(Exception):
@@ -388,29 +432,64 @@ class RegistrationFailedError(Exception):
388432
pass
389433

390434

435+
class GrantNotSupported(Exception):
436+
"""Exception raised when a grant type is not supported."""
437+
438+
pass
439+
440+
391441
class OAuthClient:
392442
WELL_KNOWN = "/.well-known/oauth-authorization-server"
393-
394-
def __init__(self, server_url: AnyHttpUrl, provider: OAuthClientProvider):
443+
GRANT_TYPE: str = "authorization_code"
444+
445+
def __init__(
446+
self,
447+
server_url: AnyHttpUrl,
448+
provider: OAuthClientProvider,
449+
scope: str | None = None,
450+
):
395451
self.server_url = server_url
396452
self.http_client = httpx.AsyncClient()
397453
self.provider = provider
398-
self._registration: DynamicClientRegistration | None = None
454+
self.scope = scope
399455

400-
async def auth(self):
401-
metadata = await self.discover_auth_metadata() or self._default_metadata()
456+
@property
457+
def discovery_url(self) -> AnyHttpUrl:
458+
base_url = str(self.server_url).rstrip("/")
459+
parsed_url = urlparse(base_url)
460+
# HTTPS is required by RFC 8414
461+
discovery_url = f"https://{parsed_url.netloc}{self.WELL_KNOWN}"
462+
return AnyHttpUrl(discovery_url)
463+
464+
async def _obtain_client(
465+
self, metadata: ServerMetadataDiscovery
466+
) -> DynamicClientRegistration:
467+
"""
468+
Obtain a client by either reading it from the OAuthProvider or registering it.
469+
"""
402470
if metadata.registration_endpoint is None:
403471
raise NotFoundError("Registration endpoint not found")
404-
self._registration = await self.dynamic_client_registration(
405-
self.provider.client_metadata, metadata.registration_endpoint
406-
)
407-
if self._registration is None:
408-
raise RegistrationFailedError(
409-
f"Registration at {metadata.registration_endpoint} failed"
472+
473+
if registration := await self.provider.client_registration(metadata.issuer):
474+
return registration
475+
else:
476+
registration = await self.dynamic_client_registration(
477+
self.provider.client_metadata, metadata.registration_endpoint
410478
)
411-
self.provider.save_client_information(self._registration)
479+
if registration is None:
480+
raise RegistrationFailedError(
481+
f"Registration at {metadata.registration_endpoint} failed"
482+
)
412483

413-
def _default_metadata(self) -> ServerMetadataDiscovery:
484+
await self.provider.store_client_registration(metadata.issuer, registration)
485+
return registration
486+
487+
def default_metadata(self) -> ServerMetadataDiscovery:
488+
"""
489+
Returns default endpoints as specified in
490+
https://spec.modelcontextprotocol.io/specification/draft/basic/authorization/
491+
for the server.
492+
"""
414493
base_url = AnyHttpUrl(str(self.server_url).rstrip("/"))
415494
return ServerMetadataDiscovery(
416495
issuer=base_url,
@@ -423,10 +502,11 @@ def _default_metadata(self) -> ServerMetadataDiscovery:
423502
)
424503

425504
async def discover_auth_metadata(self) -> ServerMetadataDiscovery | None:
426-
discovery_url = self._build_discovery_url()
427-
505+
"""
506+
Use RFC 8414 to discover the authorization server metadata.
507+
"""
428508
try:
429-
response = await self.http_client.get(str(discovery_url))
509+
response = await self.http_client.get(str(self.discovery_url))
430510
if response.status_code == 404:
431511
return None
432512
response.raise_for_status()
@@ -439,31 +519,12 @@ async def discover_auth_metadata(self) -> ServerMetadataDiscovery | None:
439519
logger.error(f"Error during auth metadata discovery: {e}")
440520
raise
441521

442-
def _build_discovery_url(self) -> AnyHttpUrl:
443-
base_url = str(self.server_url).rstrip("/")
444-
parsed_url = urlparse(base_url)
445-
# HTTPS is required by RFC 8414
446-
discovery_url = f"https://{parsed_url.netloc}{self.WELL_KNOWN}"
447-
return AnyHttpUrl(discovery_url)
448-
449522
async def dynamic_client_registration(
450523
self, client_metadata: ClientMetadata, registration_endpoint: AnyHttpUrl
451524
) -> DynamicClientRegistration | None:
452525
"""
453526
Register a client dynamically with an OAuth 2.0 authorization server
454527
following RFC 7591.
455-
456-
Args:
457-
client_metadata: Typed client registration metadata
458-
registration_endpoint: Where to register clients.
459-
If None, will use discovery
460-
461-
Returns:
462-
DynamicClientRegistrationResponse if successful, None otherwise
463-
464-
Raises:
465-
httpx.HTTPStatusError: If the server returns an error status code
466-
Exception: For other errors during registration
467528
"""
468529
headers = {"Content-Type": "application/json", "Accept": "application/json"}
469530

@@ -493,3 +554,145 @@ async def dynamic_client_registration(
493554
logger.error(f"Unexpected error during registration: {e}")
494555

495556
return None
557+
558+
async def exchange_authorization(
559+
self,
560+
metadata: ServerMetadataDiscovery,
561+
registration: DynamicClientRegistration,
562+
code_verifier: str,
563+
authorization_code: str,
564+
) -> AccessToken:
565+
"""Exchange an authorization code for an access token using OAuth 2.1 with PKCE.
566+
567+
Args:
568+
registration: The client registration information
569+
code_verifier: The PKCE code verifier used to generate the code challenge
570+
authorization_code: The authorization code received from the authorization
571+
server
572+
573+
Returns:
574+
AccessToken: The resulting access token
575+
576+
Raises:
577+
GrantNotSupported: If the grant type is not supported
578+
httpx.HTTPStatusError: If the token endpoint request fails
579+
"""
580+
if self.GRANT_TYPE not in (registration.grant_types or []):
581+
raise GrantNotSupported(f"Grant type {self.GRANT_TYPE} not supported")
582+
583+
code_verifier = self.provider.code_verifier()
584+
# Get token endpoint from server metadata or use default
585+
token_endpoint = str(metadata.token_endpoint)
586+
587+
# Prepare token request parameters
588+
data = {
589+
"grant_type": self.GRANT_TYPE,
590+
"code": authorization_code,
591+
"redirect_uri": str(self.provider.redirect_url),
592+
"client_id": registration.client_id,
593+
"code_verifier": code_verifier,
594+
}
595+
596+
# Add client secret if available (optional in OAuth 2.1)
597+
if registration.client_secret:
598+
data["client_secret"] = registration.client_secret
599+
600+
headers = {
601+
"Content-Type": "application/x-www-form-urlencoded",
602+
"Accept": "application/json",
603+
}
604+
605+
try:
606+
response = await self.http_client.post(
607+
token_endpoint, data=data, headers=headers
608+
)
609+
response.raise_for_status()
610+
token_data = response.json()
611+
612+
# Create and return the token
613+
return AccessToken(**token_data)
614+
615+
except httpx.HTTPStatusError as e:
616+
logger.error(f"HTTP error during token exchange: {e.response.status_code}")
617+
if e.response.content:
618+
try:
619+
error_data = json.loads(e.response.content)
620+
logger.error(f"Error details: {error_data}")
621+
except json.JSONDecodeError:
622+
logger.error(f"Error content: {e.response.content}")
623+
raise
624+
except Exception as e:
625+
logger.error(f"Unexpected error during token exchange: {e}")
626+
raise
627+
628+
async def auth(self, authorization_code: str, code_verifier: str) -> AccessToken:
629+
"""
630+
Complete the OAuth 2.1 authorization flow by exchanging authorization code
631+
for tokens.
632+
633+
Args:
634+
authorization_code: The authorization code received from the authorization
635+
server
636+
code_verifier: The PKCE code verifier used to generate the code challenge
637+
638+
Returns:
639+
AccessToken: The resulting access token
640+
"""
641+
metadata = await self.discover_auth_metadata() or self.default_metadata()
642+
registration = await self._obtain_client(metadata)
643+
644+
code_verifier = self.provider.code_verifier()
645+
646+
authorization_url = self.get_authorization_url(
647+
metadata.authorization_endpoint,
648+
self.provider.redirect_url,
649+
registration.client_id,
650+
code_verifier,
651+
self.scope,
652+
)
653+
654+
await self.provider.open_user_agent(AnyHttpUrl(authorization_url))
655+
656+
return await self.exchange_authorization(
657+
metadata, registration, code_verifier, authorization_code
658+
)
659+
660+
def get_authorization_url(
661+
self,
662+
authorization_endpoint: AnyHttpUrl,
663+
redirect_uri: AnyHttpUrl,
664+
client_id: str,
665+
code_verifier: str,
666+
scope: str | None = None,
667+
) -> AnyHttpUrl:
668+
"""Generate an OAuth 2.1 authorization URL for the user agent.
669+
670+
This method generates a URL that the user agent (browser) should visit to
671+
authenticate the user and authorize the application. It includes PKCE
672+
(Proof Key for Code Exchange) for enhanced security as required by OAuth 2.1.
673+
"""
674+
# Create a custom verifier for this authorization request
675+
code_verifier = self.provider.code_verifier()
676+
677+
# Generate code challenge from verifier using SHA-256
678+
code_challenge = (
679+
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
680+
.decode()
681+
.rstrip("=")
682+
)
683+
684+
# Build authorization URL with necessary parameters
685+
params = {
686+
"response_type": "code",
687+
"client_id": client_id,
688+
"redirect_uri": str(redirect_uri),
689+
"code_challenge": code_challenge,
690+
"code_challenge_method": "S256",
691+
}
692+
693+
# Add scope if provided or use the one from registration
694+
if scope:
695+
params["scope"] = scope
696+
697+
# Construct the full authorization URL
698+
return AnyHttpUrl(f"{authorization_endpoint}?{urlencode(params)}")

0 commit comments

Comments
 (0)