Skip to content

Commit 8b39a33

Browse files
committed
Prefer SSL connections by default
Switch the default SSL mode from 'disabled' to 'prefer'. This matches libpq's behavior and is a sensible thing to do. Fixes: #654
1 parent 690048d commit 8b39a33

File tree

3 files changed

+49
-36
lines changed

3 files changed

+49
-36
lines changed

asyncpg/connect_utils.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
397397
if ssl is None:
398398
ssl = os.getenv('PGSSLMODE')
399399

400+
if ssl is None:
401+
ssl = 'prefer'
402+
400403
# ssl_is_advisory is only allowed to come from the sslmode parameter.
401404
ssl_is_advisory = None
402405
if isinstance(ssl, str):
@@ -435,14 +438,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
435438
if sslmode <= SSLMODES['require']:
436439
ssl.verify_mode = ssl_module.CERT_NONE
437440
ssl_is_advisory = sslmode <= SSLMODES['prefer']
438-
439-
if ssl:
440-
for addr in addrs:
441-
if isinstance(addr, str):
442-
# UNIX socket
443-
raise exceptions.InterfaceError(
444-
'`ssl` parameter can only be enabled for TCP addresses, '
445-
'got a UNIX socket path: {!r}'.format(addr))
441+
elif ssl is True:
442+
ssl = ssl_module.create_default_context()
446443

447444
if server_settings is not None and (
448445
not isinstance(server_settings, dict) or
@@ -542,9 +539,6 @@ def connection_lost(self, exc):
542539
async def _create_ssl_connection(protocol_factory, host, port, *,
543540
loop, ssl_context, ssl_is_advisory=False):
544541

545-
if ssl_context is True:
546-
ssl_context = ssl_module.create_default_context()
547-
548542
tr, pr = await loop.create_connection(
549543
lambda: TLSUpgradeProto(loop, host, port,
550544
ssl_context, ssl_is_advisory),
@@ -625,7 +619,6 @@ async def _connect_addr(
625619

626620
if isinstance(addr, str):
627621
# UNIX socket
628-
assert not params.ssl
629622
connector = loop.create_unix_connection(proto_factory, addr)
630623
elif params.ssl:
631624
connector = _create_ssl_connection(

asyncpg/connection.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1864,7 +1864,28 @@ async def connect(dsn=None, *,
18641864
Pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
18651865
require an SSL connection. If ``True``, a default SSL context
18661866
returned by `ssl.create_default_context() <create_default_context_>`_
1867-
will be used.
1867+
will be used. The value can also be one of the following strings:
1868+
1869+
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
1870+
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
1871+
if SSL connection fails
1872+
- ``'allow'`` - currently equivalent to ``'prefer'``
1873+
- ``'require'`` - only try an SSL connection. Certificate
1874+
verifiction errors are ignored
1875+
- ``'verify-ca'`` - only try an SSL connection, and verify
1876+
that the server certificate is issued by a trusted certificate
1877+
authority (CA)
1878+
- ``'verify-full'`` - only try an SSL connection, verify
1879+
that the server certificate is issued by a trusted CA and
1880+
that the requested server host name matches that in the
1881+
certificate.
1882+
1883+
The default is ``'prefer'``: try an SSL connection and fallback to
1884+
non-SSL connection if that fails.
1885+
1886+
.. note::
1887+
1888+
*ssl* is ignored for Unix domain socket communication.
18681889
18691890
:param dict server_settings:
18701891
An optional dict of server runtime parameters. Refer to
@@ -1921,6 +1942,9 @@ async def connect(dsn=None, *,
19211942
.. versionchanged:: 0.22.0
19221943
Added the *record_class* parameter.
19231944
1945+
.. versionchanged:: 0.22.0
1946+
The *ssl* argument now defaults to ``'prefer'``.
1947+
19241948
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
19251949
.. _create_default_context:
19261950
https://docs.python.org/3/library/ssl.html#ssl.create_default_context

tests/test_connect.py

+19-23
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,9 @@ class TestConnectParams(tb.TestCase):
318318
'result': ([('host', 123)], {
319319
'user': 'user',
320320
'password': 'passw',
321-
'database': 'testdb'})
321+
'database': 'testdb',
322+
'ssl': True,
323+
'ssl_is_advisory': True})
322324
},
323325

324326
{
@@ -384,7 +386,7 @@ class TestConnectParams(tb.TestCase):
384386
'user': 'user3',
385387
'password': '123123',
386388
'database': 'abcdef',
387-
'ssl': ssl.SSLContext,
389+
'ssl': True,
388390
'ssl_is_advisory': True})
389391
},
390392

@@ -461,7 +463,7 @@ class TestConnectParams(tb.TestCase):
461463
'user': 'me',
462464
'password': 'ask',
463465
'database': 'db',
464-
'ssl': ssl.SSLContext,
466+
'ssl': True,
465467
'ssl_is_advisory': False})
466468
},
467469

@@ -617,7 +619,7 @@ def run_testcase(self, testcase):
617619
password = testcase.get('password')
618620
passfile = testcase.get('passfile')
619621
database = testcase.get('database')
620-
ssl = testcase.get('ssl')
622+
sslmode = testcase.get('ssl')
621623
server_settings = testcase.get('server_settings')
622624

623625
expected = testcase.get('result')
@@ -640,21 +642,25 @@ def run_testcase(self, testcase):
640642

641643
addrs, params = connect_utils._parse_connect_dsn_and_args(
642644
dsn=dsn, host=host, port=port, user=user, password=password,
643-
passfile=passfile, database=database, ssl=ssl,
645+
passfile=passfile, database=database, ssl=sslmode,
644646
connect_timeout=None, server_settings=server_settings)
645647

646-
params = {k: v for k, v in params._asdict().items()
647-
if v is not None}
648+
params = {
649+
k: v for k, v in params._asdict().items() if v is not None
650+
}
651+
652+
if isinstance(params.get('ssl'), ssl.SSLContext):
653+
params['ssl'] = True
648654

649655
result = (addrs, params)
650656

651657
if expected is not None:
652-
for k, v in expected[1].items():
653-
# If `expected` contains a type, allow that to "match" any
654-
# instance of that type tyat `result` may contain. We need
655-
# this because different SSLContexts don't compare equal.
656-
if isinstance(v, type) and isinstance(result[1].get(k), v):
657-
result[1][k] = v
658+
if 'ssl' not in expected[1]:
659+
# Avoid the hassle of specifying the default SSL mode
660+
# unless explicitly tested for.
661+
params.pop('ssl', None)
662+
params.pop('ssl_is_advisory', None)
663+
658664
self.assertEqual(expected, result, 'Testcase: {}'.format(testcase))
659665

660666
def test_test_connect_params_environ(self):
@@ -1063,16 +1069,6 @@ async def verify_fails(sslmode):
10631069
await verify_fails('verify-ca')
10641070
await verify_fails('verify-full')
10651071

1066-
async def test_connection_ssl_unix(self):
1067-
ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1068-
ssl_context.load_verify_locations(SSL_CA_CERT_FILE)
1069-
1070-
with self.assertRaisesRegex(asyncpg.InterfaceError,
1071-
'can only be enabled for TCP addresses'):
1072-
await self.connect(
1073-
host='/tmp',
1074-
ssl=ssl_context)
1075-
10761072
async def test_connection_implicit_host(self):
10771073
conn_spec = self.get_connection_spec()
10781074
con = await asyncpg.connect(

0 commit comments

Comments
 (0)