@@ -53,6 +53,7 @@ def parse(cls, sslmode):
53
53
'database' ,
54
54
'ssl' ,
55
55
'sslmode' ,
56
+ 'direct_tls' ,
56
57
'connect_timeout' ,
57
58
'server_settings' ,
58
59
])
@@ -258,7 +259,7 @@ def _dot_postgresql_path(filename) -> pathlib.Path:
258
259
259
260
def _parse_connect_dsn_and_args (* , dsn , host , port , user ,
260
261
password , passfile , database , ssl ,
261
- connect_timeout , server_settings ):
262
+ direct_tls , connect_timeout , server_settings ):
262
263
# `auth_hosts` is the version of host information for the purposes
263
264
# of reading the pgpass file.
264
265
auth_hosts = None
@@ -601,8 +602,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
601
602
602
603
params = _ConnectionParameters (
603
604
user = user , password = password , database = database , ssl = ssl ,
604
- sslmode = sslmode , connect_timeout = connect_timeout ,
605
- server_settings = server_settings )
605
+ sslmode = sslmode , direct_tls = direct_tls ,
606
+ connect_timeout = connect_timeout , server_settings = server_settings )
606
607
607
608
return addrs , params
608
609
@@ -612,7 +613,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
612
613
statement_cache_size ,
613
614
max_cached_statement_lifetime ,
614
615
max_cacheable_statement_size ,
615
- ssl , server_settings ):
616
+ ssl , direct_tls , server_settings ):
616
617
617
618
local_vars = locals ()
618
619
for var_name in {'max_cacheable_statement_size' ,
@@ -640,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
640
641
addrs , params = _parse_connect_dsn_and_args (
641
642
dsn = dsn , host = host , port = port , user = user ,
642
643
password = password , passfile = passfile , ssl = ssl ,
643
- database = database , connect_timeout = timeout ,
644
- server_settings = server_settings )
644
+ direct_tls = direct_tls , database = database ,
645
+ connect_timeout = timeout , server_settings = server_settings )
645
646
646
647
config = _ClientConfiguration (
647
648
command_timeout = command_timeout ,
@@ -812,6 +813,14 @@ async def __connect_addr(
812
813
if isinstance (addr , str ):
813
814
# UNIX socket
814
815
connector = loop .create_unix_connection (proto_factory , addr )
816
+
817
+ elif params .ssl and params .direct_tls :
818
+ # if ssl and direct_tls are given, skip STARTTLS and perform direct
819
+ # SSL connection
820
+ connector = loop .create_connection (
821
+ proto_factory , * addr , ssl = params .ssl
822
+ )
823
+
815
824
elif params .ssl :
816
825
connector = _create_ssl_connection (
817
826
proto_factory , * addr , loop = loop , ssl_context = params .ssl ,
0 commit comments