6
6
authorization specification.
7
7
"""
8
8
9
+ import base64
10
+ import hashlib
9
11
import json
10
12
import logging
11
13
from datetime import datetime , timedelta
12
14
from typing import Any , Protocol
13
- from urllib .parse import urlparse
15
+ from urllib .parse import urlencode , urlparse
14
16
15
17
import httpx
16
18
from pydantic import AnyHttpUrl , BaseModel , ConfigDict , Field
@@ -373,7 +375,49 @@ class OAuthClientProvider(Protocol):
373
375
@property
374
376
def client_metadata (self ) -> ClientMetadata : ...
375
377
376
- def save_client_information (self , metadata : DynamicClientRegistration ) -> None : ...
378
+ @property
379
+ def redirect_url (self ) -> AnyHttpUrl : ...
380
+
381
+ async def open_user_agent (self , url : AnyHttpUrl ) -> None :
382
+ """
383
+ Opens the user agent to the given URL.
384
+ """
385
+ ...
386
+
387
+ async def client_registration (
388
+ self , endpoint : AnyHttpUrl
389
+ ) -> DynamicClientRegistration | None :
390
+ """
391
+ Loads the client registration for the given endpoint.
392
+ """
393
+ ...
394
+
395
+ async def store_client_registration (
396
+ self , endpoint : AnyHttpUrl , metadata : DynamicClientRegistration
397
+ ) -> None :
398
+ """
399
+ Stores the client registration to be retreived for the next session
400
+ """
401
+ ...
402
+
403
+ def code_verifier (self ) -> str :
404
+ """
405
+ Loads the PKCE code verifier for the current session.
406
+ See https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1
407
+ """
408
+ ...
409
+
410
+ async def token (self ) -> AccessToken | None :
411
+ """
412
+ Loads the token for the current session.
413
+ """
414
+ ...
415
+
416
+ async def store_token (self , token : AccessToken ) -> None :
417
+ """
418
+ Stores the token to be retreived for the next session
419
+ """
420
+ ...
377
421
378
422
379
423
class NotFoundError (Exception ):
@@ -388,29 +432,64 @@ class RegistrationFailedError(Exception):
388
432
pass
389
433
390
434
435
+ class GrantNotSupported (Exception ):
436
+ """Exception raised when a grant type is not supported."""
437
+
438
+ pass
439
+
440
+
391
441
class OAuthClient :
392
442
WELL_KNOWN = "/.well-known/oauth-authorization-server"
393
-
394
- def __init__ (self , server_url : AnyHttpUrl , provider : OAuthClientProvider ):
443
+ GRANT_TYPE : str = "authorization_code"
444
+
445
+ def __init__ (
446
+ self ,
447
+ server_url : AnyHttpUrl ,
448
+ provider : OAuthClientProvider ,
449
+ scope : str | None = None ,
450
+ ):
395
451
self .server_url = server_url
396
452
self .http_client = httpx .AsyncClient ()
397
453
self .provider = provider
398
- self ._registration : DynamicClientRegistration | None = None
454
+ self .scope = scope
399
455
400
- async def auth (self ):
401
- metadata = await self .discover_auth_metadata () or self ._default_metadata ()
456
+ @property
457
+ def discovery_url (self ) -> AnyHttpUrl :
458
+ base_url = str (self .server_url ).rstrip ("/" )
459
+ parsed_url = urlparse (base_url )
460
+ # HTTPS is required by RFC 8414
461
+ discovery_url = f"https://{ parsed_url .netloc } { self .WELL_KNOWN } "
462
+ return AnyHttpUrl (discovery_url )
463
+
464
+ async def _obtain_client (
465
+ self , metadata : ServerMetadataDiscovery
466
+ ) -> DynamicClientRegistration :
467
+ """
468
+ Obtain a client by either reading it from the OAuthProvider or registering it.
469
+ """
402
470
if metadata .registration_endpoint is None :
403
471
raise NotFoundError ("Registration endpoint not found" )
404
- self . _registration = await self . dynamic_client_registration (
405
- self .provider .client_metadata , metadata .registration_endpoint
406
- )
407
- if self . _registration is None :
408
- raise RegistrationFailedError (
409
- f"Registration at { metadata .registration_endpoint } failed"
472
+
473
+ if registration := await self .provider .client_registration ( metadata .issuer ):
474
+ return registration
475
+ else :
476
+ registration = await self . dynamic_client_registration (
477
+ self . provider . client_metadata , metadata .registration_endpoint
410
478
)
411
- self .provider .save_client_information (self ._registration )
479
+ if registration is None :
480
+ raise RegistrationFailedError (
481
+ f"Registration at { metadata .registration_endpoint } failed"
482
+ )
412
483
413
- def _default_metadata (self ) -> ServerMetadataDiscovery :
484
+ await self .provider .store_client_registration (metadata .issuer , registration )
485
+ return registration
486
+
487
+ def default_metadata (self ) -> ServerMetadataDiscovery :
488
+ """
489
+ Returns default endpoints as specified in
490
+ https://spec.modelcontextprotocol.io/specification/draft/basic/authorization/
491
+ for the server.
492
+ """
414
493
base_url = AnyHttpUrl (str (self .server_url ).rstrip ("/" ))
415
494
return ServerMetadataDiscovery (
416
495
issuer = base_url ,
@@ -423,10 +502,11 @@ def _default_metadata(self) -> ServerMetadataDiscovery:
423
502
)
424
503
425
504
async def discover_auth_metadata (self ) -> ServerMetadataDiscovery | None :
426
- discovery_url = self ._build_discovery_url ()
427
-
505
+ """
506
+ Use RFC 8414 to discover the authorization server metadata.
507
+ """
428
508
try :
429
- response = await self .http_client .get (str (discovery_url ))
509
+ response = await self .http_client .get (str (self . discovery_url ))
430
510
if response .status_code == 404 :
431
511
return None
432
512
response .raise_for_status ()
@@ -439,31 +519,12 @@ async def discover_auth_metadata(self) -> ServerMetadataDiscovery | None:
439
519
logger .error (f"Error during auth metadata discovery: { e } " )
440
520
raise
441
521
442
- def _build_discovery_url (self ) -> AnyHttpUrl :
443
- base_url = str (self .server_url ).rstrip ("/" )
444
- parsed_url = urlparse (base_url )
445
- # HTTPS is required by RFC 8414
446
- discovery_url = f"https://{ parsed_url .netloc } { self .WELL_KNOWN } "
447
- return AnyHttpUrl (discovery_url )
448
-
449
522
async def dynamic_client_registration (
450
523
self , client_metadata : ClientMetadata , registration_endpoint : AnyHttpUrl
451
524
) -> DynamicClientRegistration | None :
452
525
"""
453
526
Register a client dynamically with an OAuth 2.0 authorization server
454
527
following RFC 7591.
455
-
456
- Args:
457
- client_metadata: Typed client registration metadata
458
- registration_endpoint: Where to register clients.
459
- If None, will use discovery
460
-
461
- Returns:
462
- DynamicClientRegistrationResponse if successful, None otherwise
463
-
464
- Raises:
465
- httpx.HTTPStatusError: If the server returns an error status code
466
- Exception: For other errors during registration
467
528
"""
468
529
headers = {"Content-Type" : "application/json" , "Accept" : "application/json" }
469
530
@@ -493,3 +554,145 @@ async def dynamic_client_registration(
493
554
logger .error (f"Unexpected error during registration: { e } " )
494
555
495
556
return None
557
+
558
+ async def exchange_authorization (
559
+ self ,
560
+ metadata : ServerMetadataDiscovery ,
561
+ registration : DynamicClientRegistration ,
562
+ code_verifier : str ,
563
+ authorization_code : str ,
564
+ ) -> AccessToken :
565
+ """Exchange an authorization code for an access token using OAuth 2.1 with PKCE.
566
+
567
+ Args:
568
+ registration: The client registration information
569
+ code_verifier: The PKCE code verifier used to generate the code challenge
570
+ authorization_code: The authorization code received from the authorization
571
+ server
572
+
573
+ Returns:
574
+ AccessToken: The resulting access token
575
+
576
+ Raises:
577
+ GrantNotSupported: If the grant type is not supported
578
+ httpx.HTTPStatusError: If the token endpoint request fails
579
+ """
580
+ if self .GRANT_TYPE not in (registration .grant_types or []):
581
+ raise GrantNotSupported (f"Grant type { self .GRANT_TYPE } not supported" )
582
+
583
+ code_verifier = self .provider .code_verifier ()
584
+ # Get token endpoint from server metadata or use default
585
+ token_endpoint = str (metadata .token_endpoint )
586
+
587
+ # Prepare token request parameters
588
+ data = {
589
+ "grant_type" : self .GRANT_TYPE ,
590
+ "code" : authorization_code ,
591
+ "redirect_uri" : str (self .provider .redirect_url ),
592
+ "client_id" : registration .client_id ,
593
+ "code_verifier" : code_verifier ,
594
+ }
595
+
596
+ # Add client secret if available (optional in OAuth 2.1)
597
+ if registration .client_secret :
598
+ data ["client_secret" ] = registration .client_secret
599
+
600
+ headers = {
601
+ "Content-Type" : "application/x-www-form-urlencoded" ,
602
+ "Accept" : "application/json" ,
603
+ }
604
+
605
+ try :
606
+ response = await self .http_client .post (
607
+ token_endpoint , data = data , headers = headers
608
+ )
609
+ response .raise_for_status ()
610
+ token_data = response .json ()
611
+
612
+ # Create and return the token
613
+ return AccessToken (** token_data )
614
+
615
+ except httpx .HTTPStatusError as e :
616
+ logger .error (f"HTTP error during token exchange: { e .response .status_code } " )
617
+ if e .response .content :
618
+ try :
619
+ error_data = json .loads (e .response .content )
620
+ logger .error (f"Error details: { error_data } " )
621
+ except json .JSONDecodeError :
622
+ logger .error (f"Error content: { e .response .content } " )
623
+ raise
624
+ except Exception as e :
625
+ logger .error (f"Unexpected error during token exchange: { e } " )
626
+ raise
627
+
628
+ async def auth (self , authorization_code : str , code_verifier : str ) -> AccessToken :
629
+ """
630
+ Complete the OAuth 2.1 authorization flow by exchanging authorization code
631
+ for tokens.
632
+
633
+ Args:
634
+ authorization_code: The authorization code received from the authorization
635
+ server
636
+ code_verifier: The PKCE code verifier used to generate the code challenge
637
+
638
+ Returns:
639
+ AccessToken: The resulting access token
640
+ """
641
+ metadata = await self .discover_auth_metadata () or self .default_metadata ()
642
+ registration = await self ._obtain_client (metadata )
643
+
644
+ code_verifier = self .provider .code_verifier ()
645
+
646
+ authorization_url = self .get_authorization_url (
647
+ metadata .authorization_endpoint ,
648
+ self .provider .redirect_url ,
649
+ registration .client_id ,
650
+ code_verifier ,
651
+ self .scope ,
652
+ )
653
+
654
+ await self .provider .open_user_agent (AnyHttpUrl (authorization_url ))
655
+
656
+ return await self .exchange_authorization (
657
+ metadata , registration , code_verifier , authorization_code
658
+ )
659
+
660
+ def get_authorization_url (
661
+ self ,
662
+ authorization_endpoint : AnyHttpUrl ,
663
+ redirect_uri : AnyHttpUrl ,
664
+ client_id : str ,
665
+ code_verifier : str ,
666
+ scope : str | None = None ,
667
+ ) -> AnyHttpUrl :
668
+ """Generate an OAuth 2.1 authorization URL for the user agent.
669
+
670
+ This method generates a URL that the user agent (browser) should visit to
671
+ authenticate the user and authorize the application. It includes PKCE
672
+ (Proof Key for Code Exchange) for enhanced security as required by OAuth 2.1.
673
+ """
674
+ # Create a custom verifier for this authorization request
675
+ code_verifier = self .provider .code_verifier ()
676
+
677
+ # Generate code challenge from verifier using SHA-256
678
+ code_challenge = (
679
+ base64 .urlsafe_b64encode (hashlib .sha256 (code_verifier .encode ()).digest ())
680
+ .decode ()
681
+ .rstrip ("=" )
682
+ )
683
+
684
+ # Build authorization URL with necessary parameters
685
+ params = {
686
+ "response_type" : "code" ,
687
+ "client_id" : client_id ,
688
+ "redirect_uri" : str (redirect_uri ),
689
+ "code_challenge" : code_challenge ,
690
+ "code_challenge_method" : "S256" ,
691
+ }
692
+
693
+ # Add scope if provided or use the one from registration
694
+ if scope :
695
+ params ["scope" ] = scope
696
+
697
+ # Construct the full authorization URL
698
+ return AnyHttpUrl (f"{ authorization_endpoint } ?{ urlencode (params )} " )
0 commit comments