Skip to content

Commit bf74e88

Browse files
JesseDeLoorerony batistaJesse De Loore
authored
Add support for target_session_attrs (#987)
This adds support for the `target_session_attrs` connection option. Co-authored-by: rony batista <[email protected]> Co-authored-by: Jesse De Loore <[email protected]>
1 parent 7443a9e commit bf74e88

File tree

7 files changed

+377
-75
lines changed

7 files changed

+377
-75
lines changed

asyncpg/_testbase/__init__.py

+90
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,93 @@ def tearDown(self):
435435
self.con = None
436436
finally:
437437
super().tearDown()
438+
439+
440+
class HotStandbyTestCase(ClusterTestCase):
441+
442+
@classmethod
443+
def setup_cluster(cls):
444+
cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
445+
cls.start_cluster(
446+
cls.master_cluster,
447+
server_settings={
448+
'max_wal_senders': 10,
449+
'wal_level': 'hot_standby'
450+
}
451+
)
452+
453+
con = None
454+
455+
try:
456+
con = cls.loop.run_until_complete(
457+
cls.master_cluster.connect(
458+
database='postgres', user='postgres', loop=cls.loop))
459+
460+
cls.loop.run_until_complete(
461+
con.execute('''
462+
CREATE ROLE replication WITH LOGIN REPLICATION
463+
'''))
464+
465+
cls.master_cluster.trust_local_replication_by('replication')
466+
467+
conn_spec = cls.master_cluster.get_connection_spec()
468+
469+
cls.standby_cluster = cls.new_cluster(
470+
pg_cluster.HotStandbyCluster,
471+
cluster_kwargs={
472+
'master': conn_spec,
473+
'replication_user': 'replication'
474+
}
475+
)
476+
cls.start_cluster(
477+
cls.standby_cluster,
478+
server_settings={
479+
'hot_standby': True
480+
}
481+
)
482+
483+
finally:
484+
if con is not None:
485+
cls.loop.run_until_complete(con.close())
486+
487+
@classmethod
488+
def get_cluster_connection_spec(cls, cluster, kwargs={}):
489+
conn_spec = cluster.get_connection_spec()
490+
if kwargs.get('dsn'):
491+
conn_spec.pop('host')
492+
conn_spec.update(kwargs)
493+
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
494+
if 'database' not in conn_spec:
495+
conn_spec['database'] = 'postgres'
496+
if 'user' not in conn_spec:
497+
conn_spec['user'] = 'postgres'
498+
return conn_spec
499+
500+
@classmethod
501+
def get_connection_spec(cls, kwargs={}):
502+
primary_spec = cls.get_cluster_connection_spec(
503+
cls.master_cluster, kwargs
504+
)
505+
standby_spec = cls.get_cluster_connection_spec(
506+
cls.standby_cluster, kwargs
507+
)
508+
return {
509+
'host': [primary_spec['host'], standby_spec['host']],
510+
'port': [primary_spec['port'], standby_spec['port']],
511+
'database': primary_spec['database'],
512+
'user': primary_spec['user'],
513+
**kwargs
514+
}
515+
516+
@classmethod
517+
def connect_primary(cls, **kwargs):
518+
conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
519+
return pg_connection.connect(**conn_spec, loop=cls.loop)
520+
521+
@classmethod
522+
def connect_standby(cls, **kwargs):
523+
conn_spec = cls.get_cluster_connection_spec(
524+
cls.standby_cluster,
525+
kwargs
526+
)
527+
return pg_connection.connect(**conn_spec, loop=cls.loop)

asyncpg/cluster.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def init(self, **settings):
626626
'pg_basebackup init exited with status {:d}:\n{}'.format(
627627
process.returncode, output.decode()))
628628

629-
if self._pg_version <= (11, 0):
629+
if self._pg_version < (12, 0):
630630
with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
631631
f.write(textwrap.dedent("""\
632632
standby_mode = 'on'

asyncpg/connect_utils.py

+114-8
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414
import pathlib
1515
import platform
16+
import random
1617
import re
1718
import socket
1819
import ssl as ssl_module
@@ -56,6 +57,7 @@ def parse(cls, sslmode):
5657
'direct_tls',
5758
'connect_timeout',
5859
'server_settings',
60+
'target_session_attrs',
5961
])
6062

6163

@@ -260,7 +262,8 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
260262

261263
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
262264
password, passfile, database, ssl,
263-
direct_tls, connect_timeout, server_settings):
265+
direct_tls, connect_timeout, server_settings,
266+
target_session_attrs):
264267
# `auth_hosts` is the version of host information for the purposes
265268
# of reading the pgpass file.
266269
auth_hosts = None
@@ -607,10 +610,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
607610
'server_settings is expected to be None or '
608611
'a Dict[str, str]')
609612

613+
if target_session_attrs is None:
614+
615+
target_session_attrs = os.getenv(
616+
"PGTARGETSESSIONATTRS", SessionAttribute.any
617+
)
618+
try:
619+
620+
target_session_attrs = SessionAttribute(target_session_attrs)
621+
except ValueError as exc:
622+
raise exceptions.InterfaceError(
623+
"target_session_attrs is expected to be one of "
624+
"{!r}"
625+
", got {!r}".format(
626+
SessionAttribute.__members__.values, target_session_attrs
627+
)
628+
) from exc
629+
610630
params = _ConnectionParameters(
611631
user=user, password=password, database=database, ssl=ssl,
612632
sslmode=sslmode, direct_tls=direct_tls,
613-
connect_timeout=connect_timeout, server_settings=server_settings)
633+
connect_timeout=connect_timeout, server_settings=server_settings,
634+
target_session_attrs=target_session_attrs)
614635

615636
return addrs, params
616637

@@ -620,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
620641
statement_cache_size,
621642
max_cached_statement_lifetime,
622643
max_cacheable_statement_size,
623-
ssl, direct_tls, server_settings):
624-
644+
ssl, direct_tls, server_settings,
645+
target_session_attrs):
625646
local_vars = locals()
626647
for var_name in {'max_cacheable_statement_size',
627648
'max_cached_statement_lifetime',
@@ -649,7 +670,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
649670
dsn=dsn, host=host, port=port, user=user,
650671
password=password, passfile=passfile, ssl=ssl,
651672
direct_tls=direct_tls, database=database,
652-
connect_timeout=timeout, server_settings=server_settings)
673+
connect_timeout=timeout, server_settings=server_settings,
674+
target_session_attrs=target_session_attrs)
653675

654676
config = _ClientConfiguration(
655677
command_timeout=command_timeout,
@@ -882,18 +904,84 @@ async def __connect_addr(
882904
return con
883905

884906

907+
class SessionAttribute(str, enum.Enum):
908+
any = 'any'
909+
primary = 'primary'
910+
standby = 'standby'
911+
prefer_standby = 'prefer-standby'
912+
read_write = "read-write"
913+
read_only = "read-only"
914+
915+
916+
def _accept_in_hot_standby(should_be_in_hot_standby: bool):
917+
"""
918+
If the server didn't report "in_hot_standby" at startup, we must determine
919+
the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
920+
If the server allows a connection and states it is in recovery it must
921+
be a replica/standby server.
922+
"""
923+
async def can_be_used(connection):
924+
settings = connection.get_settings()
925+
hot_standby_status = getattr(settings, 'in_hot_standby', None)
926+
if hot_standby_status is not None:
927+
is_in_hot_standby = hot_standby_status == 'on'
928+
else:
929+
is_in_hot_standby = await connection.fetchval(
930+
"SELECT pg_catalog.pg_is_in_recovery()"
931+
)
932+
return is_in_hot_standby == should_be_in_hot_standby
933+
934+
return can_be_used
935+
936+
937+
def _accept_read_only(should_be_read_only: bool):
938+
"""
939+
Verify the server has not set default_transaction_read_only=True
940+
"""
941+
async def can_be_used(connection):
942+
settings = connection.get_settings()
943+
is_readonly = getattr(settings, 'default_transaction_read_only', 'off')
944+
945+
if is_readonly == "on":
946+
return should_be_read_only
947+
948+
return await _accept_in_hot_standby(should_be_read_only)(connection)
949+
return can_be_used
950+
951+
952+
async def _accept_any(_):
953+
return True
954+
955+
956+
target_attrs_check = {
957+
SessionAttribute.any: _accept_any,
958+
SessionAttribute.primary: _accept_in_hot_standby(False),
959+
SessionAttribute.standby: _accept_in_hot_standby(True),
960+
SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
961+
SessionAttribute.read_write: _accept_read_only(False),
962+
SessionAttribute.read_only: _accept_read_only(True),
963+
}
964+
965+
966+
async def _can_use_connection(connection, attr: SessionAttribute):
967+
can_use = target_attrs_check[attr]
968+
return await can_use(connection)
969+
970+
885971
async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
886972
if loop is None:
887973
loop = asyncio.get_event_loop()
888974

889975
addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
976+
target_attr = params.target_session_attrs
890977

978+
candidates = []
979+
chosen_connection = None
891980
last_error = None
892-
addr = None
893981
for addr in addrs:
894982
before = time.monotonic()
895983
try:
896-
return await _connect_addr(
984+
conn = await _connect_addr(
897985
addr=addr,
898986
loop=loop,
899987
timeout=timeout,
@@ -902,12 +990,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
902990
connection_class=connection_class,
903991
record_class=record_class,
904992
)
993+
candidates.append(conn)
994+
if await _can_use_connection(conn, target_attr):
995+
chosen_connection = conn
996+
break
905997
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
906998
last_error = ex
907999
finally:
9081000
timeout -= time.monotonic() - before
1001+
else:
1002+
if target_attr == SessionAttribute.prefer_standby and candidates:
1003+
chosen_connection = random.choice(candidates)
1004+
1005+
await asyncio.gather(
1006+
(c.close() for c in candidates if c is not chosen_connection),
1007+
return_exceptions=True
1008+
)
1009+
1010+
if chosen_connection:
1011+
return chosen_connection
9091012

910-
raise last_error
1013+
raise last_error or exceptions.TargetServerAttributeNotMatched(
1014+
'None of the hosts match the target attribute requirement '
1015+
'{!r}'.format(target_attr)
1016+
)
9111017

9121018

9131019
async def _cancel(*, loop, addr, params: _ConnectionParameters,

asyncpg/connection.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -1792,7 +1792,8 @@ async def connect(dsn=None, *,
17921792
direct_tls=False,
17931793
connection_class=Connection,
17941794
record_class=protocol.Record,
1795-
server_settings=None):
1795+
server_settings=None,
1796+
target_session_attrs=None):
17961797
r"""A coroutine to establish a connection to a PostgreSQL server.
17971798
17981799
The connection parameters may be specified either as a connection
@@ -2003,6 +2004,22 @@ async def connect(dsn=None, *,
20032004
this connection object. Must be a subclass of
20042005
:class:`~asyncpg.Record`.
20052006
2007+
:param SessionAttribute target_session_attrs:
2008+
If specified, check that the host has the correct attribute.
2009+
Can be one of:
2010+
"any": the first successfully connected host
2011+
"primary": the host must NOT be in hot standby mode
2012+
"standby": the host must be in hot standby mode
2013+
"read-write": the host must allow writes
2014+
"read-only": the host most NOT allow writes
2015+
"prefer-standby": first try to find a standby host, but if
2016+
none of the listed hosts is a standby server,
2017+
return any of them.
2018+
2019+
If not specified will try to use PGTARGETSESSIONATTRS
2020+
from the environment.
2021+
Defaults to "any" if no value is set.
2022+
20062023
:return: A :class:`~asyncpg.connection.Connection` instance.
20072024
20082025
Example:
@@ -2109,6 +2126,7 @@ async def connect(dsn=None, *,
21092126
statement_cache_size=statement_cache_size,
21102127
max_cached_statement_lifetime=max_cached_statement_lifetime,
21112128
max_cacheable_statement_size=max_cacheable_statement_size,
2129+
target_session_attrs=target_session_attrs
21122130
)
21132131

21142132

asyncpg/exceptions/_base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
1414
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
1515
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
16-
'UnsupportedClientFeatureError')
16+
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched')
1717

1818

1919
def _is_asyncpg_class(cls):
@@ -244,6 +244,10 @@ class ProtocolError(InternalClientError):
244244
"""Unexpected condition in the handling of PostgreSQL protocol input."""
245245

246246

247+
class TargetServerAttributeNotMatched(InternalClientError):
248+
"""Could not find a host that satisfies the target attribute requirement"""
249+
250+
247251
class OutdatedSchemaCacheError(InternalClientError):
248252
"""A value decoding error caused by a schema change before row fetching."""
249253

0 commit comments

Comments
 (0)