PYTHON-2834 Direct read/write retries to another mongos if possible #1421


Merged: 9 commits, Nov 14, 2023
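
In outline, the change records the server that failed an attempt and asks server selection to avoid it on the retry, falling back to the full candidate list when no alternative mongos exists. A minimal sketch of the flow, with simplified signatures (the real ones are in the diff below):

from pymongo.errors import ConnectionFailure
from pymongo.server_selectors import writable_server_selector
from pymongo.topology_description import TOPOLOGY_TYPE

def run_with_retry(client, func):
    # Sketch only: mongoses that already failed this operation are
    # deprioritized so the retry prefers a different one.
    deprioritized = []
    last_exc = None
    for _attempt in range(2):  # one initial attempt plus one retry
        server = client._select_server(
            writable_server_selector,
            session=None,
            deprioritized_servers=deprioritized,
        )
        try:
            return func(server)
        except ConnectionFailure as exc:
            last_exc = exc
            # Only sharded topologies track failed mongoses; elsewhere the
            # list stays empty and selection behaves exactly as before.
            if client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
                deprioritized.append(server)
    raise last_exc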
14 changes: 12 additions & 2 deletions pymongo/mongo_client.py
@@ -1277,6 +1277,7 @@ def _select_server(
server_selector: Callable[[Selection], Selection],
session: Optional[ClientSession],
address: Optional[_Address] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> Server:
"""Select a server to run an operation on this client.

@@ -1300,7 +1301,9 @@ def _select_server(
if not server:
raise AutoReconnect("server %s:%s no longer available" % address) # noqa: UP031
else:
server = topology.select_server(server_selector)
server = topology.select_server(
server_selector, deprioritized_servers=deprioritized_servers
)
return server
except PyMongoError as exc:
# Server selection errors in a transaction are transient.
@@ -2291,6 +2294,7 @@ def __init__(
)
self._address = address
self._server: Server = None # type: ignore
self._deprioritized_servers: list[Server] = []

def run(self) -> T:
"""Runs the supplied func() and attempts a retry
@@ -2359,6 +2363,9 @@ def run(self) -> T:
if self._last_error is None:
self._last_error = exc

if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
self._deprioritized_servers.append(self._server)

def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
@@ -2397,7 +2404,10 @@ def _get_server(self) -> Server:
Abstraction to connect to server
"""
return self._client._select_server(
self._server_selector, self._session, address=self._address
self._server_selector,
self._session,
address=self._address,
deprioritized_servers=self._deprioritized_servers,
)

def _write(self) -> T:
26 changes: 22 additions & 4 deletions pymongo/topology.py
@@ -282,11 +282,13 @@ def _select_server(
selector: Callable[[Selection], Selection],
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> Server:
servers = self.select_servers(selector, server_selection_timeout, address)
if len(servers) == 1:
return servers[0]
server1, server2 = random.sample(servers, 2)
filtered_servers = _filter_servers(servers, deprioritized_servers)
@ShaneHarvey (Member) commented on Nov 13, 2023:
Can we rename filtered_servers->servers so the rest of this code doesn't have to change?

if len(filtered_servers) == 1:
return filtered_servers[0]
server1, server2 = random.sample(filtered_servers, 2)
if server1.pool.operation_count <= server2.pool.operation_count:
return server1
else:
@@ -297,9 +299,12 @@ def select_server(
selector: Callable[[Selection], Selection],
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> Server:
"""Like select_servers, but choose a random server if several match."""
server = self._select_server(selector, server_selection_timeout, address)
server = self._select_server(
selector, server_selection_timeout, address, deprioritized_servers
)
if _csot.get_timeout():
_csot.set_rtt(server.description.min_round_trip_time)
return server
@@ -931,3 +936,16 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]


def _filter_servers(
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
return candidates

filtered = [server for server in candidates if server not in deprioritized_servers]

# If not possible to pick a prioritized server, return the original list
return filtered or candidates
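
The "filtered or candidates" fallback is what keeps a single-mongos deployment retryable: once every candidate has been deprioritized, selection falls back to the original list and retries the same server rather than failing outright. For example (s1 as in the unit test in test/test_topology.py below):

# Every candidate is deprioritized, so the original list comes back and
# the retry can still proceed on the only available mongos.
assert _filter_servers([s1], deprioritized_servers=[s1]) == [s1]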
49 changes: 49 additions & 0 deletions test/test_retryable_reads.py
@@ -20,6 +20,9 @@
import sys
import threading

from bson import SON
from pymongo.errors import AutoReconnect

sys.path[0:0] = [""]

from test import (
@@ -31,9 +34,12 @@
)
from test.utils import (
CMAPListener,
EventListener,
OvertCommandListener,
SpecTestCreator,
rs_client,
rs_or_single_client,
set_fail_point,
)
from test.utils_spec_runner import SpecRunner

@@ -221,5 +227,48 @@ def test_pool_paused_error_is_retryable(self):
self.assertEqual(1, len(failed), msg)


class TestRetryableReads(IntegrationTest):
@client_context.require_multiple_mongoses
@client_context.require_failCommand_fail_point
def test_retryable_reads_in_sharded_cluster_multiple_available(self):
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"closeConnection": True,
"appName": "retryableReadTest",
},
}

mongos_clients = []

for mongos in client_context.mongos_seeds().split(","):
client = rs_or_single_client(mongos)
set_fail_point(client, fail_command)
self.addCleanup(client.close)
mongos_clients.append(client)

listener = OvertCommandListener()
client = rs_or_single_client(
client_context.mongos_seeds(),
appName="retryableReadTest",
event_listeners=[listener],
retryReads=True,
)

with self.fail_point(fail_command):
with self.assertRaises(AutoReconnect):
client.t.t.find_one({})

# Disable failpoints on each mongos
for client in mongos_clients:
fail_command["mode"] = "off"
set_fail_point(client, fail_command)

self.assertEqual(len(listener.failed_events), 2)
self.assertEqual(len(listener.succeeded_events), 0)


if __name__ == "__main__":
unittest.main()
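
The assertion of exactly two failed events holds because the fail point is armed with "times": 1 on every mongos: the first attempt fails once on one mongos, deprioritization steers the retry to another, which also fails its single time, and retryable reads permit no third attempt. A stronger check (hypothetical, not part of this PR) could assert the failures came from distinct mongoses:

# CommandFailedEvent.connection_id is the (host, port) address of the
# server that ran the command; two distinct addresses show the retry
# actually moved to a different mongos.
addresses = {event.connection_id for event in listener.failed_events}
assert len(addresses) == 2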
42 changes: 42 additions & 0 deletions test/test_retryable_writes.py
@@ -31,6 +31,7 @@
OvertCommandListener,
SpecTestCreator,
rs_or_single_client,
set_fail_point,
)
from test.utils_spec_runner import SpecRunner
from test.version import Version
@@ -40,6 +41,7 @@
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo.errors import (
AutoReconnect,
ConnectionFailure,
OperationFailure,
ServerSelectionTimeoutError,
@@ -469,6 +471,46 @@ def test_batch_splitting_retry_fails(self):
self.assertEqual(final_txn, expected_txn)
self.assertEqual(coll.find_one(projection={"_id": True}), {"_id": 1})

@client_context.require_multiple_mongoses
@client_context.require_failCommand_fail_point
def test_retryable_writes_in_sharded_cluster_multiple_available(self):
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"closeConnection": True,
"appName": "retryableWriteTest",
},
}

mongos_clients = []

for mongos in client_context.mongos_seeds().split(","):
Member commented:
Suggested change:

for mongos in client_context.mongoses:

Contributor Author replied:
That change causes the following error:

h = ('localhost', 27017)

    def _connection_string(h):
>       if h.startswith(("mongodb://", "mongodb+srv://")):
E       AttributeError: 'tuple' object has no attribute 'startswith'

test/utils.py:544: AttributeError
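
The failure is a plain type mismatch: client_context.mongoses yields (host, port) tuples, while rs_or_single_client passes its argument through a helper that expects a connection-string-like value. A minimal illustration (helper behavior inferred from the traceback above):

h = ("localhost", 27017)        # what client_context.mongoses yields
"%s:%d" % h                     # "localhost:27017" would be accepted
h.startswith("mongodb://")      # AttributeError: tuples have no startswith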

client = rs_or_single_client(mongos)
set_fail_point(client, fail_command)
self.addCleanup(client.close)
mongos_clients.append(client)

listener = OvertCommandListener()
client = rs_or_single_client(
client_context.mongos_seeds(),
appName="retryableWriteTest",
event_listeners=[listener],
retryWrites=True,
)

with self.assertRaises(AutoReconnect):
client.t.t.insert_one({"x": 1})

# Disable failpoints on each mongos
for client in mongos_clients:
fail_command["mode"] = "off"
set_fail_point(client, fail_command)

self.assertEqual(len(listener.failed_events), 2)
self.assertEqual(len(listener.succeeded_events), 0)


class TestWriteConcernError(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
20 changes: 19 additions & 1 deletion test/test_topology.py
@@ -30,11 +30,12 @@
from pymongo.monitor import Monitor
from pymongo.pool import PoolOptions
from pymongo.read_preferences import ReadPreference, Secondary
from pymongo.server import Server
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import any_server_selector, writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.settings import TopologySettings
from pymongo.topology import Topology, _ErrorContext
from pymongo.topology import Topology, _ErrorContext, _filter_servers
from pymongo.topology_description import TOPOLOGY_TYPE


@@ -681,6 +682,23 @@ def test_unexpected_load_balancer(self):
self.assertNotIn(("a", 27017), t.description.server_descriptions())
self.assertEqual(t.description.topology_type_name, "Unknown")

def test_filtered_server_selection(self):
s1 = Server(ServerDescription(("localhost", 27017)), pool=object(), monitor=object()) # type: ignore[arg-type]
s2 = Server(ServerDescription(("localhost2", 27017)), pool=object(), monitor=object()) # type: ignore[arg-type]
servers = [s1, s2]

result = _filter_servers(servers, deprioritized_servers=[s2])
self.assertEqual(result, [s1])

result = _filter_servers(servers, deprioritized_servers=[s1, s2])
self.assertEqual(result, servers)

result = _filter_servers(servers, deprioritized_servers=[])
self.assertEqual(result, servers)

result = _filter_servers(servers)
self.assertEqual(result, servers)


def wait_for_primary(topology):
"""Wait for a Topology to discover a writable server.
6 changes: 6 additions & 0 deletions test/utils.py
@@ -1153,3 +1153,9 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callbac
raise AssertionError(f"Unsupported cursorType: {cursor_type}")
else:
arguments[c2s] = arguments.pop(arg_name)


def set_fail_point(client, command_args):
Member commented:
Can you use the with self.fail_point() helper instead?

Contributor Author replied:
We need to set the fail point on each individual mongos to ensure the failure occurs on every node. with self.fail_point() does not set the fail point correctly.

cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
client.admin.command(cmd)
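
The SON wrapper exists only to guarantee that configureFailPoint is the first key in the command document, since the server reads the command name from the first field. Disabling the fail point, for instance, ends up sending the following (a sketch equivalent to the cleanup loop in the tests above):

# Sends {"configureFailPoint": "failCommand", "mode": "off"} with key
# order preserved, which is what the server requires.
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update({"mode": "off"})
client.admin.command(cmd)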