Skip to content

Commit 5dc6034

Browse files
authored
PYTHON-2834 Direct read/write retries to another mongos if possible (mongodb#1421)
1 parent b0cd7d2 commit 5dc6034

File tree

6 files changed

+147
-4
lines changed

6 files changed

+147
-4
lines changed

pymongo/mongo_client.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,7 @@ def _select_server(
12771277
server_selector: Callable[[Selection], Selection],
12781278
session: Optional[ClientSession],
12791279
address: Optional[_Address] = None,
1280+
deprioritized_servers: Optional[list[Server]] = None,
12801281
) -> Server:
12811282
"""Select a server to run an operation on this client.
12821283
@@ -1300,7 +1301,9 @@ def _select_server(
13001301
if not server:
13011302
raise AutoReconnect("server %s:%s no longer available" % address) # noqa: UP031
13021303
else:
1303-
server = topology.select_server(server_selector)
1304+
server = topology.select_server(
1305+
server_selector, deprioritized_servers=deprioritized_servers
1306+
)
13041307
return server
13051308
except PyMongoError as exc:
13061309
# Server selection errors in a transaction are transient.
@@ -2291,6 +2294,7 @@ def __init__(
22912294
)
22922295
self._address = address
22932296
self._server: Server = None # type: ignore
2297+
self._deprioritized_servers: list[Server] = []
22942298

22952299
def run(self) -> T:
22962300
"""Runs the supplied func() and attempts a retry
@@ -2359,6 +2363,9 @@ def run(self) -> T:
23592363
if self._last_error is None:
23602364
self._last_error = exc
23612365

2366+
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
2367+
self._deprioritized_servers.append(self._server)
2368+
23622369
def _is_not_eligible_for_retry(self) -> bool:
23632370
"""Checks if the exchange is not eligible for retry"""
23642371
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
@@ -2397,7 +2404,10 @@ def _get_server(self) -> Server:
23972404
Abstraction to connect to server
23982405
"""
23992406
return self._client._select_server(
2400-
self._server_selector, self._session, address=self._address
2407+
self._server_selector,
2408+
self._session,
2409+
address=self._address,
2410+
deprioritized_servers=self._deprioritized_servers,
24012411
)
24022412

24032413
def _write(self) -> T:

pymongo/topology.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,10 @@ def _select_server(
282282
selector: Callable[[Selection], Selection],
283283
server_selection_timeout: Optional[float] = None,
284284
address: Optional[_Address] = None,
285+
deprioritized_servers: Optional[list[Server]] = None,
285286
) -> Server:
286287
servers = self.select_servers(selector, server_selection_timeout, address)
288+
servers = _filter_servers(servers, deprioritized_servers)
287289
if len(servers) == 1:
288290
return servers[0]
289291
server1, server2 = random.sample(servers, 2)
@@ -297,9 +299,12 @@ def select_server(
297299
selector: Callable[[Selection], Selection],
298300
server_selection_timeout: Optional[float] = None,
299301
address: Optional[_Address] = None,
302+
deprioritized_servers: Optional[list[Server]] = None,
300303
) -> Server:
301304
"""Like select_servers, but choose a random server if several match."""
302-
server = self._select_server(selector, server_selection_timeout, address)
305+
server = self._select_server(
306+
selector, server_selection_timeout, address, deprioritized_servers
307+
)
303308
if _csot.get_timeout():
304309
_csot.set_rtt(server.description.min_round_trip_time)
305310
return server
@@ -931,3 +936,16 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
931936
if current_tv["processId"] != new_tv["processId"]:
932937
return False
933938
return current_tv["counter"] > new_tv["counter"]
939+
940+
941+
def _filter_servers(
942+
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
943+
) -> list[Server]:
944+
"""Filter out deprioritized servers from a list of server candidates."""
945+
if not deprioritized_servers:
946+
return candidates
947+
948+
filtered = [server for server in candidates if server not in deprioritized_servers]
949+
950+
# If not possible to pick a prioritized server, return the original list
951+
return filtered or candidates

test/test_retryable_reads.py

+49
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
import sys
2121
import threading
2222

23+
from bson import SON
24+
from pymongo.errors import AutoReconnect
25+
2326
sys.path[0:0] = [""]
2427

2528
from test import (
@@ -31,9 +34,12 @@
3134
)
3235
from test.utils import (
3336
CMAPListener,
37+
EventListener,
3438
OvertCommandListener,
3539
SpecTestCreator,
40+
rs_client,
3641
rs_or_single_client,
42+
set_fail_point,
3743
)
3844
from test.utils_spec_runner import SpecRunner
3945

@@ -221,5 +227,48 @@ def test_pool_paused_error_is_retryable(self):
221227
self.assertEqual(1, len(failed), msg)
222228

223229

230+
class TestRetryableReads(IntegrationTest):
231+
@client_context.require_multiple_mongoses
232+
@client_context.require_failCommand_fail_point
233+
def test_retryable_reads_in_sharded_cluster_multiple_available(self):
234+
fail_command = {
235+
"configureFailPoint": "failCommand",
236+
"mode": {"times": 1},
237+
"data": {
238+
"failCommands": ["find"],
239+
"closeConnection": True,
240+
"appName": "retryableReadTest",
241+
},
242+
}
243+
244+
mongos_clients = []
245+
246+
for mongos in client_context.mongos_seeds().split(","):
247+
client = rs_or_single_client(mongos)
248+
set_fail_point(client, fail_command)
249+
self.addCleanup(client.close)
250+
mongos_clients.append(client)
251+
252+
listener = OvertCommandListener()
253+
client = rs_or_single_client(
254+
client_context.mongos_seeds(),
255+
appName="retryableReadTest",
256+
event_listeners=[listener],
257+
retryReads=True,
258+
)
259+
260+
with self.fail_point(fail_command):
261+
with self.assertRaises(AutoReconnect):
262+
client.t.t.find_one({})
263+
264+
# Disable failpoints on each mongos
265+
for client in mongos_clients:
266+
fail_command["mode"] = "off"
267+
set_fail_point(client, fail_command)
268+
269+
self.assertEqual(len(listener.failed_events), 2)
270+
self.assertEqual(len(listener.succeeded_events), 0)
271+
272+
224273
if __name__ == "__main__":
225274
unittest.main()

test/test_retryable_writes.py

+42
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
OvertCommandListener,
3232
SpecTestCreator,
3333
rs_or_single_client,
34+
set_fail_point,
3435
)
3536
from test.utils_spec_runner import SpecRunner
3637
from test.version import Version
@@ -40,6 +41,7 @@
4041
from bson.raw_bson import RawBSONDocument
4142
from bson.son import SON
4243
from pymongo.errors import (
44+
AutoReconnect,
4345
ConnectionFailure,
4446
OperationFailure,
4547
ServerSelectionTimeoutError,
@@ -469,6 +471,46 @@ def test_batch_splitting_retry_fails(self):
469471
self.assertEqual(final_txn, expected_txn)
470472
self.assertEqual(coll.find_one(projection={"_id": True}), {"_id": 1})
471473

474+
@client_context.require_multiple_mongoses
475+
@client_context.require_failCommand_fail_point
476+
def test_retryable_writes_in_sharded_cluster_multiple_available(self):
477+
fail_command = {
478+
"configureFailPoint": "failCommand",
479+
"mode": {"times": 1},
480+
"data": {
481+
"failCommands": ["insert"],
482+
"closeConnection": True,
483+
"appName": "retryableWriteTest",
484+
},
485+
}
486+
487+
mongos_clients = []
488+
489+
for mongos in client_context.mongos_seeds().split(","):
490+
client = rs_or_single_client(mongos)
491+
set_fail_point(client, fail_command)
492+
self.addCleanup(client.close)
493+
mongos_clients.append(client)
494+
495+
listener = OvertCommandListener()
496+
client = rs_or_single_client(
497+
client_context.mongos_seeds(),
498+
appName="retryableWriteTest",
499+
event_listeners=[listener],
500+
retryWrites=True,
501+
)
502+
503+
with self.assertRaises(AutoReconnect):
504+
client.t.t.insert_one({"x": 1})
505+
506+
# Disable failpoints on each mongos
507+
for client in mongos_clients:
508+
fail_command["mode"] = "off"
509+
set_fail_point(client, fail_command)
510+
511+
self.assertEqual(len(listener.failed_events), 2)
512+
self.assertEqual(len(listener.succeeded_events), 0)
513+
472514

473515
class TestWriteConcernError(IntegrationTest):
474516
RUN_ON_LOAD_BALANCER = True

test/test_topology.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@
3030
from pymongo.monitor import Monitor
3131
from pymongo.pool import PoolOptions
3232
from pymongo.read_preferences import ReadPreference, Secondary
33+
from pymongo.server import Server
3334
from pymongo.server_description import ServerDescription
3435
from pymongo.server_selectors import any_server_selector, writable_server_selector
3536
from pymongo.server_type import SERVER_TYPE
3637
from pymongo.settings import TopologySettings
37-
from pymongo.topology import Topology, _ErrorContext
38+
from pymongo.topology import Topology, _ErrorContext, _filter_servers
3839
from pymongo.topology_description import TOPOLOGY_TYPE
3940

4041

@@ -685,6 +686,23 @@ def test_unexpected_load_balancer(self):
685686
self.assertNotIn(("a", 27017), t.description.server_descriptions())
686687
self.assertEqual(t.description.topology_type_name, "Unknown")
687688

689+
def test_filtered_server_selection(self):
690+
s1 = Server(ServerDescription(("localhost", 27017)), pool=object(), monitor=object()) # type: ignore[arg-type]
691+
s2 = Server(ServerDescription(("localhost2", 27017)), pool=object(), monitor=object()) # type: ignore[arg-type]
692+
servers = [s1, s2]
693+
694+
result = _filter_servers(servers, deprioritized_servers=[s2])
695+
self.assertEqual(result, [s1])
696+
697+
result = _filter_servers(servers, deprioritized_servers=[s1, s2])
698+
self.assertEqual(result, servers)
699+
700+
result = _filter_servers(servers, deprioritized_servers=[])
701+
self.assertEqual(result, servers)
702+
703+
result = _filter_servers(servers)
704+
self.assertEqual(result, servers)
705+
688706

689707
def wait_for_primary(topology):
690708
"""Wait for a Topology to discover a writable server.

test/utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -1153,3 +1153,9 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callbac
11531153
raise AssertionError(f"Unsupported cursorType: {cursor_type}")
11541154
else:
11551155
arguments[c2s] = arguments.pop(arg_name)
1156+
1157+
1158+
def set_fail_point(client, command_args):
1159+
cmd = SON([("configureFailPoint", "failCommand")])
1160+
cmd.update(command_args)
1161+
client.admin.command(cmd)

0 commit comments

Comments
 (0)