13
13
import os
14
14
import pathlib
15
15
import platform
16
+ import random
16
17
import re
17
18
import socket
18
19
import ssl as ssl_module
@@ -56,6 +57,7 @@ def parse(cls, sslmode):
56
57
'direct_tls' ,
57
58
'connect_timeout' ,
58
59
'server_settings' ,
60
+ 'target_session_attrs' ,
59
61
])
60
62
61
63
@@ -260,7 +262,8 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
260
262
261
263
def _parse_connect_dsn_and_args (* , dsn , host , port , user ,
262
264
password , passfile , database , ssl ,
263
- direct_tls , connect_timeout , server_settings ):
265
+ direct_tls , connect_timeout , server_settings ,
266
+ target_session_attrs ):
264
267
# `auth_hosts` is the version of host information for the purposes
265
268
# of reading the pgpass file.
266
269
auth_hosts = None
@@ -607,10 +610,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
607
610
'server_settings is expected to be None or '
608
611
'a Dict[str, str]' )
609
612
613
+ if target_session_attrs is None :
614
+
615
+ target_session_attrs = os .getenv (
616
+ "PGTARGETSESSIONATTRS" , SessionAttribute .any
617
+ )
618
+ try :
619
+
620
+ target_session_attrs = SessionAttribute (target_session_attrs )
621
+ except ValueError as exc :
622
+ raise exceptions .InterfaceError (
623
+ "target_session_attrs is expected to be one of "
624
+ "{!r}"
625
+ ", got {!r}" .format (
626
+ SessionAttribute .__members__ .values , target_session_attrs
627
+ )
628
+ ) from exc
629
+
610
630
params = _ConnectionParameters (
611
631
user = user , password = password , database = database , ssl = ssl ,
612
632
sslmode = sslmode , direct_tls = direct_tls ,
613
- connect_timeout = connect_timeout , server_settings = server_settings )
633
+ connect_timeout = connect_timeout , server_settings = server_settings ,
634
+ target_session_attrs = target_session_attrs )
614
635
615
636
return addrs , params
616
637
@@ -620,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
620
641
statement_cache_size ,
621
642
max_cached_statement_lifetime ,
622
643
max_cacheable_statement_size ,
623
- ssl , direct_tls , server_settings ):
624
-
644
+ ssl , direct_tls , server_settings ,
645
+ target_session_attrs ):
625
646
local_vars = locals ()
626
647
for var_name in {'max_cacheable_statement_size' ,
627
648
'max_cached_statement_lifetime' ,
@@ -649,7 +670,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
649
670
dsn = dsn , host = host , port = port , user = user ,
650
671
password = password , passfile = passfile , ssl = ssl ,
651
672
direct_tls = direct_tls , database = database ,
652
- connect_timeout = timeout , server_settings = server_settings )
673
+ connect_timeout = timeout , server_settings = server_settings ,
674
+ target_session_attrs = target_session_attrs )
653
675
654
676
config = _ClientConfiguration (
655
677
command_timeout = command_timeout ,
@@ -882,18 +904,84 @@ async def __connect_addr(
882
904
return con
883
905
884
906
907
+ class SessionAttribute (str , enum .Enum ):
908
+ any = 'any'
909
+ primary = 'primary'
910
+ standby = 'standby'
911
+ prefer_standby = 'prefer-standby'
912
+ read_write = "read-write"
913
+ read_only = "read-only"
914
+
915
+
916
+ def _accept_in_hot_standby (should_be_in_hot_standby : bool ):
917
+ """
918
+ If the server didn't report "in_hot_standby" at startup, we must determine
919
+ the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
920
+ If the server allows a connection and states it is in recovery it must
921
+ be a replica/standby server.
922
+ """
923
+ async def can_be_used (connection ):
924
+ settings = connection .get_settings ()
925
+ hot_standby_status = getattr (settings , 'in_hot_standby' , None )
926
+ if hot_standby_status is not None :
927
+ is_in_hot_standby = hot_standby_status == 'on'
928
+ else :
929
+ is_in_hot_standby = await connection .fetchval (
930
+ "SELECT pg_catalog.pg_is_in_recovery()"
931
+ )
932
+ return is_in_hot_standby == should_be_in_hot_standby
933
+
934
+ return can_be_used
935
+
936
+
937
+ def _accept_read_only (should_be_read_only : bool ):
938
+ """
939
+ Verify the server has not set default_transaction_read_only=True
940
+ """
941
+ async def can_be_used (connection ):
942
+ settings = connection .get_settings ()
943
+ is_readonly = getattr (settings , 'default_transaction_read_only' , 'off' )
944
+
945
+ if is_readonly == "on" :
946
+ return should_be_read_only
947
+
948
+ return await _accept_in_hot_standby (should_be_read_only )(connection )
949
+ return can_be_used
950
+
951
+
952
+ async def _accept_any (_ ):
953
+ return True
954
+
955
+
956
+ target_attrs_check = {
957
+ SessionAttribute .any : _accept_any ,
958
+ SessionAttribute .primary : _accept_in_hot_standby (False ),
959
+ SessionAttribute .standby : _accept_in_hot_standby (True ),
960
+ SessionAttribute .prefer_standby : _accept_in_hot_standby (True ),
961
+ SessionAttribute .read_write : _accept_read_only (False ),
962
+ SessionAttribute .read_only : _accept_read_only (True ),
963
+ }
964
+
965
+
966
+ async def _can_use_connection (connection , attr : SessionAttribute ):
967
+ can_use = target_attrs_check [attr ]
968
+ return await can_use (connection )
969
+
970
+
885
971
async def _connect (* , loop , timeout , connection_class , record_class , ** kwargs ):
886
972
if loop is None :
887
973
loop = asyncio .get_event_loop ()
888
974
889
975
addrs , params , config = _parse_connect_arguments (timeout = timeout , ** kwargs )
976
+ target_attr = params .target_session_attrs
890
977
978
+ candidates = []
979
+ chosen_connection = None
891
980
last_error = None
892
- addr = None
893
981
for addr in addrs :
894
982
before = time .monotonic ()
895
983
try :
896
- return await _connect_addr (
984
+ conn = await _connect_addr (
897
985
addr = addr ,
898
986
loop = loop ,
899
987
timeout = timeout ,
@@ -902,12 +990,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
902
990
connection_class = connection_class ,
903
991
record_class = record_class ,
904
992
)
993
+ candidates .append (conn )
994
+ if await _can_use_connection (conn , target_attr ):
995
+ chosen_connection = conn
996
+ break
905
997
except (OSError , asyncio .TimeoutError , ConnectionError ) as ex :
906
998
last_error = ex
907
999
finally :
908
1000
timeout -= time .monotonic () - before
1001
+ else :
1002
+ if target_attr == SessionAttribute .prefer_standby and candidates :
1003
+ chosen_connection = random .choice (candidates )
1004
+
1005
+ await asyncio .gather (
1006
+ (c .close () for c in candidates if c is not chosen_connection ),
1007
+ return_exceptions = True
1008
+ )
1009
+
1010
+ if chosen_connection :
1011
+ return chosen_connection
909
1012
910
- raise last_error
1013
+ raise last_error or exceptions .TargetServerAttributeNotMatched (
1014
+ 'None of the hosts match the target attribute requirement '
1015
+ '{!r}' .format (target_attr )
1016
+ )
911
1017
912
1018
913
1019
async def _cancel (* , loop , addr , params : _ConnectionParameters ,
0 commit comments