13
13
import os
14
14
import pathlib
15
15
import platform
16
+ import random
16
17
import re
17
18
import socket
18
19
import ssl as ssl_module
@@ -56,6 +57,7 @@ def parse(cls, sslmode):
56
57
'direct_tls' ,
57
58
'connect_timeout' ,
58
59
'server_settings' ,
60
+ 'target_session_attribute' ,
59
61
])
60
62
61
63
@@ -259,7 +261,8 @@ def _dot_postgresql_path(filename) -> pathlib.Path:
259
261
260
262
def _parse_connect_dsn_and_args (* , dsn , host , port , user ,
261
263
password , passfile , database , ssl ,
262
- direct_tls , connect_timeout , server_settings ):
264
+ direct_tls , connect_timeout , server_settings ,
265
+ target_session_attribute ):
263
266
# `auth_hosts` is the version of host information for the purposes
264
267
# of reading the pgpass file.
265
268
auth_hosts = None
@@ -603,7 +606,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
603
606
params = _ConnectionParameters (
604
607
user = user , password = password , database = database , ssl = ssl ,
605
608
sslmode = sslmode , direct_tls = direct_tls ,
606
- connect_timeout = connect_timeout , server_settings = server_settings )
609
+ connect_timeout = connect_timeout , server_settings = server_settings ,
610
+ target_session_attribute = target_session_attribute )
607
611
608
612
return addrs , params
609
613
@@ -613,8 +617,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
613
617
statement_cache_size ,
614
618
max_cached_statement_lifetime ,
615
619
max_cacheable_statement_size ,
616
- ssl , direct_tls , server_settings ):
617
-
620
+ ssl , direct_tls , server_settings ,
621
+ target_session_attribute ):
618
622
local_vars = locals ()
619
623
for var_name in {'max_cacheable_statement_size' ,
620
624
'max_cached_statement_lifetime' ,
@@ -642,7 +646,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
642
646
dsn = dsn , host = host , port = port , user = user ,
643
647
password = password , passfile = passfile , ssl = ssl ,
644
648
direct_tls = direct_tls , database = database ,
645
- connect_timeout = timeout , server_settings = server_settings )
649
+ connect_timeout = timeout , server_settings = server_settings ,
650
+ target_session_attribute = target_session_attribute )
646
651
647
652
config = _ClientConfiguration (
648
653
command_timeout = command_timeout ,
@@ -875,18 +880,64 @@ async def __connect_addr(
875
880
return con
876
881
877
882
883
+ class SessionAttribute (str , enum .Enum ):
884
+ any = 'any'
885
+ primary = 'primary'
886
+ standby = 'standby'
887
+ prefer_standby = 'prefer-standby'
888
+
889
+
890
+ def _accept_in_hot_standby (should_be_in_hot_standby : bool ):
891
+ """
892
+ If the server didn't report "in_hot_standby" at startup, we must determine
893
+ the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
894
+ """
895
+ async def can_be_used (connection ):
896
+ settings = connection .get_settings ()
897
+ hot_standby_status = getattr (settings , 'in_hot_standby' , None )
898
+ if hot_standby_status is not None :
899
+ is_in_hot_standby = hot_standby_status == 'on'
900
+ else :
901
+ is_in_hot_standby = await connection .fetchval (
902
+ "SELECT pg_catalog.pg_is_in_recovery()"
903
+ )
904
+
905
+ return is_in_hot_standby == should_be_in_hot_standby
906
+
907
+ return can_be_used
908
+
909
+
910
+ async def _accept_any (_ ):
911
+ return True
912
+
913
+
914
+ target_attrs_check = {
915
+ SessionAttribute .any : _accept_any ,
916
+ SessionAttribute .primary : _accept_in_hot_standby (False ),
917
+ SessionAttribute .standby : _accept_in_hot_standby (True ),
918
+ SessionAttribute .prefer_standby : _accept_in_hot_standby (True ),
919
+ }
920
+
921
+
922
+ async def _can_use_connection (connection , attr : SessionAttribute ):
923
+ can_use = target_attrs_check [attr ]
924
+ return await can_use (connection )
925
+
926
+
878
927
async def _connect (* , loop , timeout , connection_class , record_class , ** kwargs ):
879
928
if loop is None :
880
929
loop = asyncio .get_event_loop ()
881
930
882
931
addrs , params , config = _parse_connect_arguments (timeout = timeout , ** kwargs )
932
+ target_attr = params .target_session_attribute
883
933
934
+ candidates = []
935
+ chosen_connection = None
884
936
last_error = None
885
- addr = None
886
937
for addr in addrs :
887
938
before = time .monotonic ()
888
939
try :
889
- return await _connect_addr (
940
+ conn = await _connect_addr (
890
941
addr = addr ,
891
942
loop = loop ,
892
943
timeout = timeout ,
@@ -895,12 +946,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
895
946
connection_class = connection_class ,
896
947
record_class = record_class ,
897
948
)
949
+ candidates .append (conn )
950
+ if await _can_use_connection (conn , target_attr ):
951
+ chosen_connection = conn
952
+ break
898
953
except (OSError , asyncio .TimeoutError , ConnectionError ) as ex :
899
954
last_error = ex
900
955
finally :
901
956
timeout -= time .monotonic () - before
957
+ else :
958
+ if target_attr == SessionAttribute .prefer_standby and candidates :
959
+ chosen_connection = random .choice (candidates )
960
+
961
+ await asyncio .gather (
962
+ (c .close () for c in candidates if c is not chosen_connection ),
963
+ return_exceptions = True
964
+ )
965
+
966
+ if chosen_connection :
967
+ return chosen_connection
902
968
903
- raise last_error
969
+ raise last_error or exceptions .TargetServerAttributeNotMatched (
970
+ 'None of the hosts match the target attribute requirement '
971
+ '{!r}' .format (target_attr )
972
+ )
904
973
905
974
906
975
async def _cancel (* , loop , addr , params : _ConnectionParameters ,
0 commit comments