Skip to content

Commit a48d77e

Browse files
committed
Fix tests
1 parent b5074e2 commit a48d77e

File tree

7 files changed

+48
-34
lines changed

7 files changed

+48
-34
lines changed

pymongo/asynchronous/auth_oidc.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""MONGODB-OIDC Authentication helpers."""
1616
from __future__ import annotations
1717

18-
import threading
18+
import asyncio
1919
import time
2020
from dataclasses import dataclass, field
2121
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
@@ -36,6 +36,7 @@
3636
)
3737
from pymongo.errors import ConfigurationError, OperationFailure
3838
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
39+
from pymongo.lock import Lock, _async_create_lock
3940

4041
if TYPE_CHECKING:
4142
from pymongo.asynchronous.pool import AsyncConnection
@@ -81,7 +82,7 @@ class _OIDCAuthenticator:
8182
access_token: Optional[str] = field(default=None)
8283
idp_info: Optional[OIDCIdPInfo] = field(default=None)
8384
token_gen_id: int = field(default=0)
84-
lock: threading.Lock = field(default_factory=threading.Lock)
85+
lock: Lock = field(default_factory=_async_create_lock)
8586
last_call_time: float = field(default=0)
8687

8788
async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
@@ -164,7 +165,7 @@ async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[s
164165
# Attempt to authenticate with a JwtStepRequest.
165166
return await self._sasl_continue_jwt(conn, start_resp)
166167

167-
def _get_access_token(self) -> Optional[str]:
168+
async def _get_access_token(self) -> Optional[str]:
168169
properties = self.properties
169170
cb: Union[None, OIDCCallback]
170171
resp: OIDCCallbackResult
@@ -186,7 +187,7 @@ def _get_access_token(self) -> Optional[str]:
186187
return None
187188

188189
if not prev_token and cb is not None:
189-
with self.lock:
190+
async with self.lock:
190191
# See if the token was changed while we were waiting for the
191192
# lock.
192193
new_token = self.access_token
@@ -196,7 +197,7 @@ def _get_access_token(self) -> Optional[str]:
196197
# Ensure that we are waiting a min time between callback invocations.
197198
delta = time.time() - self.last_call_time
198199
if delta < TIME_BETWEEN_CALLS_SECONDS:
199-
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
200+
await asyncio.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
200201
self.last_call_time = time.time()
201202

202203
if is_human:
@@ -211,7 +212,10 @@ def _get_access_token(self) -> Optional[str]:
211212
idp_info=self.idp_info,
212213
username=self.properties.username,
213214
)
214-
resp = cb.fetch(context)
215+
if not _IS_SYNC:
216+
resp = await asyncio.get_running_loop().run_in_executor(None, cb.fetch, context)
217+
else:
218+
resp = cb.fetch(context)
215219
if not isinstance(resp, OIDCCallbackResult):
216220
raise ValueError(
217221
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
@@ -253,13 +257,13 @@ async def _sasl_continue_jwt(
253257
start_payload: dict = bson.decode(start_resp["payload"])
254258
if "issuer" in start_payload:
255259
self.idp_info = OIDCIdPInfo(**start_payload)
256-
access_token = self._get_access_token()
260+
access_token = await self._get_access_token()
257261
conn.oidc_token_gen_id = self.token_gen_id
258262
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
259263
return await self._run_command(conn, cmd)
260264

261265
async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]:
262-
access_token = self._get_access_token()
266+
access_token = await self._get_access_token()
263267
conn.oidc_token_gen_id = self.token_gen_id
264268
cmd = self._get_start_command({"jwt": access_token})
265269
return await self._run_command(conn, cmd)

pymongo/asynchronous/cursor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,6 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
11301130
except BaseException:
11311131
await self.close()
11321132
raise
1133-
11341133
self._address = response.address
11351134
if isinstance(response, PinnedResponse):
11361135
if not self._sock_mgr:

pymongo/asynchronous/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
6464
await conn.authenticate(reauthenticate=True)
6565
else:
6666
raise
67-
return func(*args, **kwargs)
67+
return await func(*args, **kwargs)
6868
raise
6969

7070
return cast(F, inner)

pymongo/synchronous/auth_oidc.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""MONGODB-OIDC Authentication helpers."""
1616
from __future__ import annotations
1717

18-
import threading
18+
import asyncio
1919
import time
2020
from dataclasses import dataclass, field
2121
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
@@ -36,6 +36,7 @@
3636
)
3737
from pymongo.errors import ConfigurationError, OperationFailure
3838
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
39+
from pymongo.lock import Lock, _create_lock
3940

4041
if TYPE_CHECKING:
4142
from pymongo.auth_shared import MongoCredential
@@ -81,7 +82,7 @@ class _OIDCAuthenticator:
8182
access_token: Optional[str] = field(default=None)
8283
idp_info: Optional[OIDCIdPInfo] = field(default=None)
8384
token_gen_id: int = field(default=0)
84-
lock: threading.Lock = field(default_factory=threading.Lock)
85+
lock: Lock = field(default_factory=_create_lock)
8586
last_call_time: float = field(default=0)
8687

8788
def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
@@ -211,7 +212,10 @@ def _get_access_token(self) -> Optional[str]:
211212
idp_info=self.idp_info,
212213
username=self.properties.username,
213214
)
214-
resp = cb.fetch(context)
215+
if not _IS_SYNC:
216+
resp = asyncio.get_running_loop().run_in_executor(None, cb.fetch, context)
217+
else:
218+
resp = cb.fetch(context)
215219
if not isinstance(resp, OIDCCallbackResult):
216220
raise ValueError(
217221
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"

pymongo/synchronous/cursor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,6 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
11281128
except BaseException:
11291129
self.close()
11301130
raise
1131-
11321131
self._address = response.address
11331132
if isinstance(response, PinnedResponse):
11341133
if not self._sock_mgr:

test/asynchronous/test_auth_oidc.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@
2424
from contextlib import asynccontextmanager
2525
from pathlib import Path
2626
from test.asynchronous import AsyncPyMongoTestCase
27+
from test.asynchronous.helpers import ConcurrentRunner
2728
from typing import Dict
2829

2930
import pytest
3031

3132
sys.path[0:0] = [""]
3233

33-
from test.unified_format import generate_test_classes
34+
from test.asynchronous.unified_format import generate_test_classes
3435
from test.utils_shared import EventListener, OvertCommandListener
3536

3637
from bson import SON
@@ -258,10 +259,11 @@ async def test_1_7_allowed_hosts_in_connection_string_ignored(self):
258259
uri = "mongodb+srv://example.com?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22example.com%22%5D"
259260
with self.assertRaises(ConfigurationError), warnings.catch_warnings():
260261
warnings.simplefilter("ignore")
261-
_ = AsyncMongoClient(
262+
c = AsyncMongoClient(
262263
uri,
263264
authmechanismproperties=dict(OIDC_HUMAN_CALLBACK=self.create_request_cb()),
264265
)
266+
await c.aconnect()
265267

266268
async def test_1_8_machine_idp_human_callback(self):
267269
if not os.environ.get("OIDC_IS_LOCAL"):
@@ -752,7 +754,7 @@ async def test_reauthenticate_succeeds_command(self):
752754
class TestAuthOIDCMachine(OIDCTestBase):
753755
uri: str
754756

755-
def asyncSetUp(self):
757+
async def asyncSetUp(self):
756758
self.request_called = 0
757759

758760
def create_request_cb(self, username=None, sleep=0):
@@ -791,24 +793,24 @@ async def test_1_1_callback_is_called_during_reauthentication(self):
791793
# Assert that the callback was called 1 time.
792794
self.assertEqual(self.request_called, 1)
793795

794-
# TODO REPLACE THREADS WITH TASKS
795796
async def test_1_2_callback_is_called_once_for_multiple_connections(self):
796797
# Create a ``AsyncMongoClient`` configured with a custom OIDC callback that
797798
# implements the provider logic.
798799
client = await self.create_client()
800+
await client.aconnect()
799801

800802
# Start 10 threads and run 100 find operations in each thread that all succeed.
801803
async def target():
802804
for _ in range(100):
803805
await client.test.test.find_one()
804806

805-
threads = []
806-
for _ in range(10):
807-
thread = threading.Thread(target=target)
808-
thread.start()
809-
threads.append(thread)
810-
for thread in threads:
811-
thread.join()
807+
tasks = []
808+
for i in range(10):
809+
tasks.append(ConcurrentRunner(target=target))
810+
for t in tasks:
811+
await t.start()
812+
for t in tasks:
813+
await t.join()
812814
# Assert that the callback was called 1 time.
813815
self.assertEqual(self.request_called, 1)
814816

@@ -1043,6 +1045,7 @@ async def test_4_4_speculative_authentication_should_be_ignored_on_reauthenticat
10431045
# Create an OIDC configured client that can listen for `SaslStart` commands.
10441046
listener = EventListener()
10451047
client = await self.create_client(event_listeners=[listener])
1048+
await client.aconnect()
10461049

10471050
# Preload the *Client Cache* with a valid access token to enforce Speculative Authentication.
10481051
client2 = await self.create_client()
@@ -1107,6 +1110,7 @@ async def test_speculative_auth_success(self):
11071110
client1 = await self.create_client()
11081111
await client1.test.test.find_one()
11091112
client2 = await self.create_client()
1113+
await client2.aconnect()
11101114

11111115
# Prime the cache of the second client.
11121116
client2.options.pool_options._credentials.cache.data = (

test/test_auth_oidc.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from contextlib import contextmanager
2525
from pathlib import Path
2626
from test import PyMongoTestCase
27+
from test.helpers import ConcurrentRunner
2728
from typing import Dict
2829

2930
import pytest
@@ -256,10 +257,11 @@ def test_1_7_allowed_hosts_in_connection_string_ignored(self):
256257
uri = "mongodb+srv://example.com?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22example.com%22%5D"
257258
with self.assertRaises(ConfigurationError), warnings.catch_warnings():
258259
warnings.simplefilter("ignore")
259-
_ = MongoClient(
260+
c = MongoClient(
260261
uri,
261262
authmechanismproperties=dict(OIDC_HUMAN_CALLBACK=self.create_request_cb()),
262263
)
264+
c._connect()
263265

264266
def test_1_8_machine_idp_human_callback(self):
265267
if not os.environ.get("OIDC_IS_LOCAL"):
@@ -789,24 +791,24 @@ def test_1_1_callback_is_called_during_reauthentication(self):
789791
# Assert that the callback was called 1 time.
790792
self.assertEqual(self.request_called, 1)
791793

792-
# TODO REPLACE THREADS WITH TASKS
793794
def test_1_2_callback_is_called_once_for_multiple_connections(self):
794795
# Create a ``MongoClient`` configured with a custom OIDC callback that
795796
# implements the provider logic.
796797
client = self.create_client()
798+
client._connect()
797799

798800
# Start 10 threads and run 100 find operations in each thread that all succeed.
799801
def target():
800802
for _ in range(100):
801803
client.test.test.find_one()
802804

803-
threads = []
804-
for _ in range(10):
805-
thread = threading.Thread(target=target)
806-
thread.start()
807-
threads.append(thread)
808-
for thread in threads:
809-
thread.join()
805+
tasks = []
806+
for i in range(10):
807+
tasks.append(ConcurrentRunner(target=target))
808+
for t in tasks:
809+
t.start()
810+
for t in tasks:
811+
t.join()
810812
# Assert that the callback was called 1 time.
811813
self.assertEqual(self.request_called, 1)
812814

@@ -1041,6 +1043,7 @@ def test_4_4_speculative_authentication_should_be_ignored_on_reauthentication(se
10411043
# Create an OIDC configured client that can listen for `SaslStart` commands.
10421044
listener = EventListener()
10431045
client = self.create_client(event_listeners=[listener])
1046+
client._connect()
10441047

10451048
# Preload the *Client Cache* with a valid access token to enforce Speculative Authentication.
10461049
client2 = self.create_client()
@@ -1105,6 +1108,7 @@ def test_speculative_auth_success(self):
11051108
client1 = self.create_client()
11061109
client1.test.test.find_one()
11071110
client2 = self.create_client()
1111+
client2._connect()
11081112

11091113
# Prime the cache of the second client.
11101114
client2.options.pool_options._credentials.cache.data = (

0 commit comments

Comments
 (0)