|
1 |
| -import os |
2 | 1 | import threading
|
3 | 2 | import time
|
4 |
| -from datetime import datetime, timezone |
5 |
| -from typing import Any, Dict, Optional |
| 3 | +from datetime import UTC, datetime |
| 4 | +from typing import Any |
6 | 5 |
|
7 | 6 | import httpx
|
8 | 7 | import structlog
|
9 | 8 | from fastapi import Depends, HTTPException, status
|
10 | 9 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
11 |
| -from jose import jwt, jwk, JWTError |
| 10 | +from jose import JWTError, jwk, jwt |
12 | 11 | from pydantic import BaseModel
|
13 | 12 |
|
14 | 13 | from agent_memory_server.config import settings
|
|
19 | 18 |
|
20 | 19 | class UserInfo(BaseModel):
|
21 | 20 | sub: str
|
22 |
| - aud: Optional[str] = None |
23 |
| - scope: Optional[str] = None |
24 |
| - exp: Optional[int] = None |
25 |
| - iat: Optional[int] = None |
26 |
| - iss: Optional[str] = None |
27 |
| - email: Optional[str] = None |
28 |
| - roles: Optional[list[str]] = None |
| 21 | + aud: str | None = None |
| 22 | + scope: str | None = None |
| 23 | + exp: int | None = None |
| 24 | + iat: int | None = None |
| 25 | + iss: str | None = None |
| 26 | + email: str | None = None |
| 27 | + roles: list[str] | None = None |
29 | 28 |
|
30 | 29 |
|
31 | 30 | class JWKSCache:
|
32 | 31 | def __init__(self, cache_duration: int = 3600):
|
33 |
| - self._cache: Dict[str, Any] = {} |
34 |
| - self._cache_time: Optional[float] = None |
| 32 | + self._cache: dict[str, Any] = {} |
| 33 | + self._cache_time: float | None = None |
35 | 34 | self._cache_duration = cache_duration
|
36 | 35 | self._lock = threading.Lock()
|
37 | 36 |
|
38 |
| - def get_jwks(self, jwks_url: str) -> Dict[str, Any]: |
| 37 | + def get_jwks(self, jwks_url: str) -> dict[str, Any]: |
39 | 38 | current_time = time.time()
|
40 | 39 |
|
41 | 40 | if (self._cache_time is None or
|
@@ -68,13 +67,13 @@ def get_jwks(self, jwks_url: str) -> Dict[str, Any]:
|
68 | 67 | raise HTTPException(
|
69 | 68 | status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
70 | 69 | detail=f"Unable to fetch JWKS from {jwks_url}: {str(e)}"
|
71 |
| - ) |
| 70 | + ) from e |
72 | 71 | except Exception as e:
|
73 | 72 | logger.error("Unexpected error fetching JWKS", error=str(e))
|
74 | 73 | raise HTTPException(
|
75 | 74 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
76 | 75 | detail="Internal server error while fetching JWKS"
|
77 |
| - ) |
| 76 | + ) from e |
78 | 77 |
|
79 | 78 | return self._cache
|
80 | 79 |
|
@@ -105,7 +104,7 @@ def get_public_key(token: str) -> str:
|
105 | 104 | raise HTTPException(
|
106 | 105 | status_code=status.HTTP_401_UNAUTHORIZED,
|
107 | 106 | detail="Invalid token header"
|
108 |
| - ) |
| 107 | + ) from e |
109 | 108 |
|
110 | 109 | kid = unverified_header.get("kid")
|
111 | 110 | if not kid:
|
@@ -167,7 +166,7 @@ def verify_jwt(token: str) -> UserInfo:
|
167 | 166 | options=decode_options
|
168 | 167 | )
|
169 | 168 |
|
170 |
| - current_time = int(datetime.now(timezone.utc).timestamp()) |
| 169 | + current_time = int(datetime.now(UTC).timestamp()) |
171 | 170 |
|
172 | 171 | exp = payload.get("exp")
|
173 | 172 | if exp and exp < current_time:
|
@@ -230,16 +229,16 @@ def verify_jwt(token: str) -> UserInfo:
|
230 | 229 | raise HTTPException(
|
231 | 230 | status_code=status.HTTP_401_UNAUTHORIZED,
|
232 | 231 | detail=f"Invalid JWT: {str(e)}"
|
233 |
| - ) |
| 232 | + ) from e |
234 | 233 | except Exception as e:
|
235 | 234 | logger.error("Unexpected error during JWT validation", error=str(e))
|
236 | 235 | raise HTTPException(
|
237 | 236 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
238 | 237 | detail="Internal server error during authentication"
|
239 |
| - ) |
| 238 | + ) from e |
240 | 239 |
|
241 | 240 |
|
242 |
| -def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = Depends(oauth2_scheme)) -> UserInfo: |
| 241 | +def get_current_user(credentials: HTTPAuthorizationCredentials | None = Depends(oauth2_scheme)) -> UserInfo: |
243 | 242 | if settings.disable_auth:
|
244 | 243 | logger.debug("Authentication disabled, returning default user")
|
245 | 244 | return UserInfo(
|
|
0 commit comments