1
+ # Always import asyncio
2
+ import asyncio
1
3
from collections .abc import Awaitable , Callable , Generator
2
4
from functools import wraps
3
5
from typing import Literal , NewType , ParamSpec , Protocol , TypeVar , cast , final
4
- # Always import asyncio
5
- import asyncio
6
+
6
7
7
8
# pragma: no cover
8
9
class AsyncLock (Protocol ):
9
10
"""A protocol for an asynchronous lock."""
11
+
10
12
def __init__ (self ) -> None : ...
11
13
async def __aenter__ (self ) -> None : ...
12
14
async def __aexit__ (self , exc_type , exc_val , exc_tb ) -> None : ...
13
15
14
16
15
17
# Define context types as literals
16
- AsyncContext = Literal [" asyncio" , " trio" , " unknown" ]
18
+ AsyncContext = Literal [' asyncio' , ' trio' , ' unknown' ]
17
19
18
20
19
21
# Functions for detecting async context
@@ -60,10 +62,10 @@ def _is_in_trio_context() -> bool:
60
62
# Early return if trio is not available
61
63
if not has_trio :
62
64
return False
63
-
65
+
64
66
# Import trio here since we already checked it's available
65
67
import trio
66
-
68
+
67
69
try :
68
70
# Will raise RuntimeError if not in trio context
69
71
trio .lowlevel .current_task ()
@@ -81,9 +83,9 @@ def detect_async_context() -> AsyncContext:
81
83
"""
82
84
# This branch is only taken when anyio is not installed
83
85
if not has_anyio or not _is_in_trio_context ():
84
- return " asyncio"
86
+ return ' asyncio'
85
87
86
- return " trio"
88
+ return ' trio'
87
89
88
90
89
91
_ValueType = TypeVar ('_ValueType' )
@@ -142,7 +144,9 @@ def __init__(self, coro: Awaitable[_ValueType]) -> None:
142
144
"""We need just an awaitable to work with."""
143
145
self ._coro = coro
144
146
self ._cache : _ValueType | _Sentinel = _sentinel
145
- self ._lock : AsyncLock | None = None # Will be created lazily based on the backend
147
+ self ._lock : AsyncLock | None = (
148
+ None # Will be created lazily based on the backend
149
+ )
146
150
147
151
def __await__ (self ) -> Generator [None , None , _ValueType ]:
148
152
"""
@@ -192,14 +196,14 @@ def _create_lock(self) -> AsyncLock:
192
196
"""Create the appropriate lock based on the current async context."""
193
197
context = detect_async_context ()
194
198
195
- if context == " trio" and has_anyio :
199
+ if context == ' trio' and has_anyio :
196
200
try :
197
201
import anyio
198
202
except Exception :
199
203
# Just continue to asyncio if anyio import fails
200
204
return asyncio .Lock ()
201
205
return anyio .Lock ()
202
-
206
+
203
207
# For asyncio or unknown contexts
204
208
return asyncio .Lock ()
205
209
@@ -254,4 +258,4 @@ def decorator(
254
258
) -> _AwaitableT :
255
259
return ReAwaitable (coro (* args , ** kwargs )) # type: ignore[return-value]
256
260
257
- return decorator
261
+ return decorator
0 commit comments