Skip to content

Commit bf6033c

Browse files
author
rony batista
committed
Add target session attribute connection param
1 parent 9825bbb commit bf6033c

File tree

6 files changed

+261
-58
lines changed

6 files changed

+261
-58
lines changed

asyncpg/_testbase/__init__.py

+89
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,92 @@ def tearDown(self):
435435
self.con = None
436436
finally:
437437
super().tearDown()
438+
439+
440+
class HotStandbyTestCase(ClusterTestCase):
441+
@classmethod
442+
def setup_cluster(cls):
443+
cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
444+
cls.start_cluster(
445+
cls.master_cluster,
446+
server_settings={
447+
'max_wal_senders': 10,
448+
'wal_level': 'hot_standby'
449+
}
450+
)
451+
452+
con = None
453+
454+
try:
455+
con = cls.loop.run_until_complete(
456+
cls.master_cluster.connect(
457+
database='postgres', user='postgres', loop=cls.loop))
458+
459+
cls.loop.run_until_complete(
460+
con.execute('''
461+
CREATE ROLE replication WITH LOGIN REPLICATION
462+
'''))
463+
464+
cls.master_cluster.trust_local_replication_by('replication')
465+
466+
conn_spec = cls.master_cluster.get_connection_spec()
467+
468+
cls.standby_cluster = cls.new_cluster(
469+
pg_cluster.HotStandbyCluster,
470+
cluster_kwargs={
471+
'master': conn_spec,
472+
'replication_user': 'replication'
473+
}
474+
)
475+
cls.start_cluster(
476+
cls.standby_cluster,
477+
server_settings={
478+
'hot_standby': True
479+
}
480+
)
481+
482+
finally:
483+
if con is not None:
484+
cls.loop.run_until_complete(con.close())
485+
486+
@classmethod
487+
def get_cluster_connection_spec(cls, cluster, kwargs={}):
488+
conn_spec = cluster.get_connection_spec()
489+
if kwargs.get('dsn'):
490+
conn_spec.pop('host')
491+
conn_spec.update(kwargs)
492+
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
493+
if 'database' not in conn_spec:
494+
conn_spec['database'] = 'postgres'
495+
if 'user' not in conn_spec:
496+
conn_spec['user'] = 'postgres'
497+
return conn_spec
498+
499+
@classmethod
500+
def get_connection_spec(cls, kwargs={}):
501+
primary_spec = cls.get_cluster_connection_spec(
502+
cls.master_cluster, kwargs
503+
)
504+
standby_spec = cls.get_cluster_connection_spec(
505+
cls.standby_cluster, kwargs
506+
)
507+
return {
508+
'host': [primary_spec['host'], standby_spec['host']],
509+
'port': [primary_spec['port'], standby_spec['port']],
510+
'database': primary_spec['database'],
511+
'user': primary_spec['user'],
512+
**kwargs
513+
}
514+
515+
@classmethod
516+
def connect_primary(cls, **kwargs):
517+
conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
518+
return pg_connection.connect(**conn_spec, loop=cls.loop)
519+
520+
@classmethod
521+
def connect_standby(cls, **kwargs):
522+
conn_spec = cls.get_cluster_connection_spec(
523+
cls.standby_cluster,
524+
kwargs
525+
)
526+
return pg_connection.connect(**conn_spec, loop=cls.loop)

asyncpg/connect_utils.py

+77-8
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414
import pathlib
1515
import platform
16+
import random
1617
import re
1718
import socket
1819
import ssl as ssl_module
@@ -56,6 +57,7 @@ def parse(cls, sslmode):
5657
'direct_tls',
5758
'connect_timeout',
5859
'server_settings',
60+
'target_session_attribute',
5961
])
6062

6163

@@ -259,7 +261,8 @@ def _dot_postgresql_path(filename) -> pathlib.Path:
259261

260262
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
261263
password, passfile, database, ssl,
262-
direct_tls, connect_timeout, server_settings):
264+
direct_tls, connect_timeout, server_settings,
265+
target_session_attribute):
263266
# `auth_hosts` is the version of host information for the purposes
264267
# of reading the pgpass file.
265268
auth_hosts = None
@@ -603,7 +606,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
603606
params = _ConnectionParameters(
604607
user=user, password=password, database=database, ssl=ssl,
605608
sslmode=sslmode, direct_tls=direct_tls,
606-
connect_timeout=connect_timeout, server_settings=server_settings)
609+
connect_timeout=connect_timeout, server_settings=server_settings,
610+
target_session_attribute=target_session_attribute)
607611

608612
return addrs, params
609613

@@ -613,8 +617,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
613617
statement_cache_size,
614618
max_cached_statement_lifetime,
615619
max_cacheable_statement_size,
616-
ssl, direct_tls, server_settings):
617-
620+
ssl, direct_tls, server_settings,
621+
target_session_attribute):
618622
local_vars = locals()
619623
for var_name in {'max_cacheable_statement_size',
620624
'max_cached_statement_lifetime',
@@ -642,7 +646,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
642646
dsn=dsn, host=host, port=port, user=user,
643647
password=password, passfile=passfile, ssl=ssl,
644648
direct_tls=direct_tls, database=database,
645-
connect_timeout=timeout, server_settings=server_settings)
649+
connect_timeout=timeout, server_settings=server_settings,
650+
target_session_attribute=target_session_attribute)
646651

647652
config = _ClientConfiguration(
648653
command_timeout=command_timeout,
@@ -875,18 +880,64 @@ async def __connect_addr(
875880
return con
876881

877882

883+
class SessionAttribute(str, enum.Enum):
884+
any = 'any'
885+
primary = 'primary'
886+
standby = 'standby'
887+
prefer_standby = 'prefer-standby'
888+
889+
890+
def _accept_in_hot_standby(should_be_in_hot_standby: bool):
891+
"""
892+
If the server didn't report "in_hot_standby" at startup, we must determine
893+
the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
894+
"""
895+
async def can_be_used(connection):
896+
settings = connection.get_settings()
897+
hot_standby_status = getattr(settings, 'in_hot_standby', None)
898+
if hot_standby_status is not None:
899+
is_in_hot_standby = hot_standby_status == 'on'
900+
else:
901+
is_in_hot_standby = await connection.fetchval(
902+
"SELECT pg_catalog.pg_is_in_recovery()"
903+
)
904+
905+
return is_in_hot_standby == should_be_in_hot_standby
906+
907+
return can_be_used
908+
909+
910+
async def _accept_any(_):
911+
return True
912+
913+
914+
target_attrs_check = {
915+
SessionAttribute.any: _accept_any,
916+
SessionAttribute.primary: _accept_in_hot_standby(False),
917+
SessionAttribute.standby: _accept_in_hot_standby(True),
918+
SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
919+
}
920+
921+
922+
async def _can_use_connection(connection, attr: SessionAttribute):
923+
can_use = target_attrs_check[attr]
924+
return await can_use(connection)
925+
926+
878927
async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
879928
if loop is None:
880929
loop = asyncio.get_event_loop()
881930

882931
addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
932+
target_attr = params.target_session_attribute
883933

934+
candidates = []
935+
chosen_connection = None
884936
last_error = None
885-
addr = None
886937
for addr in addrs:
887938
before = time.monotonic()
888939
try:
889-
return await _connect_addr(
940+
conn = await _connect_addr(
890941
addr=addr,
891942
loop=loop,
892943
timeout=timeout,
@@ -895,12 +946,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
895946
connection_class=connection_class,
896947
record_class=record_class,
897948
)
949+
candidates.append(conn)
950+
if await _can_use_connection(conn, target_attr):
951+
chosen_connection = conn
952+
break
898953
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
899954
last_error = ex
900955
finally:
901956
timeout -= time.monotonic() - before
957+
else:
958+
if target_attr == SessionAttribute.prefer_standby and candidates:
959+
chosen_connection = random.choice(candidates)
960+
961+
await asyncio.gather(
962+
(c.close() for c in candidates if c is not chosen_connection),
963+
return_exceptions=True
964+
)
965+
966+
if chosen_connection:
967+
return chosen_connection
902968

903-
raise last_error
969+
raise last_error or exceptions.TargetServerAttributeNotMatched(
970+
'None of the hosts match the target attribute requirement '
971+
'{!r}'.format(target_attr)
972+
)
904973

905974

906975
async def _cancel(*, loop, addr, params: _ConnectionParameters,

asyncpg/connection.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from . import serverversion
3131
from . import transaction
3232
from . import utils
33+
from .connect_utils import SessionAttribute
3334

3435

3536
class ConnectionMeta(type):
@@ -1792,7 +1793,8 @@ async def connect(dsn=None, *,
17921793
direct_tls=False,
17931794
connection_class=Connection,
17941795
record_class=protocol.Record,
1795-
server_settings=None):
1796+
server_settings=None,
1797+
target_session_attribute=SessionAttribute.any):
17961798
r"""A coroutine to establish a connection to a PostgreSQL server.
17971799
17981800
The connection parameters may be specified either as a connection
@@ -2003,6 +2005,16 @@ async def connect(dsn=None, *,
20032005
this connection object. Must be a subclass of
20042006
:class:`~asyncpg.Record`.
20052007
2008+
:param SessionAttribute target_session_attribute:
2009+
If specified, check that the host has the correct attribute.
2010+
Can be one of:
2011+
"any": the first successfully connected host
2012+
"primary": the host must NOT be in hot standby mode
2013+
"standby": the host must be in hot standby mode
2014+
"prefer-standby": first try to find a standby host, but if
2015+
none of the listed hosts is a standby server,
2016+
return any of them.
2017+
20062018
:return: A :class:`~asyncpg.connection.Connection` instance.
20072019
20082020
Example:
@@ -2087,6 +2099,15 @@ async def connect(dsn=None, *,
20872099
if record_class is not protocol.Record:
20882100
_check_record_class(record_class)
20892101

2102+
try:
2103+
target_session_attribute = SessionAttribute(target_session_attribute)
2104+
except ValueError as exc:
2105+
raise exceptions.InterfaceError(
2106+
"target_session_attribute is expected to be one of "
2107+
"'any', 'primary', 'standby' or 'prefer-standby'"
2108+
", got {!r}".format(target_session_attribute)
2109+
) from exc
2110+
20902111
if loop is None:
20912112
loop = asyncio.get_event_loop()
20922113

@@ -2109,6 +2130,7 @@ async def connect(dsn=None, *,
21092130
statement_cache_size=statement_cache_size,
21102131
max_cached_statement_lifetime=max_cached_statement_lifetime,
21112132
max_cacheable_statement_size=max_cacheable_statement_size,
2133+
target_session_attribute=target_session_attribute
21122134
)
21132135

21142136

asyncpg/exceptions/_base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
1414
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
1515
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
16-
'UnsupportedClientFeatureError')
16+
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched')
1717

1818

1919
def _is_asyncpg_class(cls):
@@ -244,6 +244,10 @@ class ProtocolError(InternalClientError):
244244
"""Unexpected condition in the handling of PostgreSQL protocol input."""
245245

246246

247+
class TargetServerAttributeNotMatched(InternalClientError):
248+
"""Could not find a host that satisfies the target attribute requirement"""
249+
250+
247251
class OutdatedSchemaCacheError(InternalClientError):
248252
"""A value decoding error caused by a schema change before row fetching."""
249253

0 commit comments

Comments
 (0)