Skip to content

Commit f2a937d

Browse files
Support direct TLS connections (i.e. no STARTTLS) (#923)
Adding direct_tls param that when equal to True alongside the ssl param being set to a ssl.SSLContext will result in a direct SSL connection being made, skipping STARTTLS implementation. Closes #906
1 parent bd19262 commit f2a937d

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

asyncpg/connect_utils.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def parse(cls, sslmode):
5353
'database',
5454
'ssl',
5555
'sslmode',
56+
'direct_tls',
5657
'connect_timeout',
5758
'server_settings',
5859
])
@@ -258,7 +259,7 @@ def _dot_postgresql_path(filename) -> pathlib.Path:
258259

259260
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
260261
password, passfile, database, ssl,
261-
connect_timeout, server_settings):
262+
direct_tls, connect_timeout, server_settings):
262263
# `auth_hosts` is the version of host information for the purposes
263264
# of reading the pgpass file.
264265
auth_hosts = None
@@ -601,8 +602,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
601602

602603
params = _ConnectionParameters(
603604
user=user, password=password, database=database, ssl=ssl,
604-
sslmode=sslmode, connect_timeout=connect_timeout,
605-
server_settings=server_settings)
605+
sslmode=sslmode, direct_tls=direct_tls,
606+
connect_timeout=connect_timeout, server_settings=server_settings)
606607

607608
return addrs, params
608609

@@ -612,7 +613,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
612613
statement_cache_size,
613614
max_cached_statement_lifetime,
614615
max_cacheable_statement_size,
615-
ssl, server_settings):
616+
ssl, direct_tls, server_settings):
616617

617618
local_vars = locals()
618619
for var_name in {'max_cacheable_statement_size',
@@ -640,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
640641
addrs, params = _parse_connect_dsn_and_args(
641642
dsn=dsn, host=host, port=port, user=user,
642643
password=password, passfile=passfile, ssl=ssl,
643-
database=database, connect_timeout=timeout,
644-
server_settings=server_settings)
644+
direct_tls=direct_tls, database=database,
645+
connect_timeout=timeout, server_settings=server_settings)
645646

646647
config = _ClientConfiguration(
647648
command_timeout=command_timeout,
@@ -812,6 +813,14 @@ async def __connect_addr(
812813
if isinstance(addr, str):
813814
# UNIX socket
814815
connector = loop.create_unix_connection(proto_factory, addr)
816+
817+
elif params.ssl and params.direct_tls:
818+
# if ssl and direct_tls are given, skip STARTTLS and perform direct
819+
# SSL connection
820+
connector = loop.create_connection(
821+
proto_factory, *addr, ssl=params.ssl
822+
)
823+
815824
elif params.ssl:
816825
connector = _create_ssl_connection(
817826
proto_factory, *addr, loop=loop, ssl_context=params.ssl,

asyncpg/connection.py

+6
Original file line numberDiff line numberDiff line change
@@ -1789,6 +1789,7 @@ async def connect(dsn=None, *,
17891789
max_cacheable_statement_size=1024 * 15,
17901790
command_timeout=None,
17911791
ssl=None,
1792+
direct_tls=False,
17921793
connection_class=Connection,
17931794
record_class=protocol.Record,
17941795
server_settings=None):
@@ -1984,6 +1985,10 @@ async def connect(dsn=None, *,
19841985
... await con.close()
19851986
>>> asyncio.run(run())
19861987
1988+
:param bool direct_tls:
1989+
Pass ``True`` to skip PostgreSQL STARTTLS mode and perform a direct
1990+
SSL connection. Must be used alongside ``ssl`` param.
1991+
19871992
:param dict server_settings:
19881993
An optional dict of server runtime parameters. Refer to
19891994
PostgreSQL documentation for
@@ -2094,6 +2099,7 @@ async def connect(dsn=None, *,
20942099
password=password,
20952100
passfile=passfile,
20962101
ssl=ssl,
2102+
direct_tls=direct_tls,
20972103
database=database,
20982104
server_settings=server_settings,
20992105
command_timeout=command_timeout,

tests/test_connect.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,8 @@ def run_testcase(self, testcase):
811811
addrs, params = connect_utils._parse_connect_dsn_and_args(
812812
dsn=dsn, host=host, port=port, user=user, password=password,
813813
passfile=passfile, database=database, ssl=sslmode,
814-
connect_timeout=None, server_settings=server_settings)
814+
direct_tls=False, connect_timeout=None,
815+
server_settings=server_settings)
815816

816817
params = {
817818
k: v for k, v in params._asdict().items()
@@ -829,6 +830,10 @@ def run_testcase(self, testcase):
829830
# unless explicitly tested for.
830831
params.pop('ssl', None)
831832
params.pop('sslmode', None)
833+
if 'direct_tls' not in expected[1]:
834+
# Avoid the hassle of specifying direct_tls
835+
# unless explicitly tested for
836+
params.pop('direct_tls', False)
832837

833838
self.assertEqual(expected, result, 'Testcase: {}'.format(testcase))
834839

0 commit comments

Comments
 (0)