Skip to content

Commit 5836a8f

Browse files
committed
Add support for SSL connections.
Closes: #25.
1 parent c550388 commit 5836a8f

File tree

3 files changed

+198
-50
lines changed

3 files changed

+198
-50
lines changed

asyncpg/_testbase.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,14 @@ def get_server_settings(cls):
178178
'log_connections': 'on'
179179
}
180180

181+
@classmethod
182+
def setup_cluster(cls):
183+
cls.cluster = _start_default_cluster(cls.get_server_settings())
184+
181185
@classmethod
182186
def setUpClass(cls):
183187
super().setUpClass()
184-
cls.cluster = _start_default_cluster(cls.get_server_settings())
188+
cls.setup_cluster()
185189

186190
def create_pool(self, pool_class=pg_pool.Pool, **kwargs):
187191
conn_spec = self.cluster.get_connection_spec()

asyncpg/connection.py

+121-32
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,12 @@ class Connection(metaclass=ConnectionMeta):
4242
'_stmt_cache', '_stmts_to_close',
4343
'_addr', '_opts', '_command_timeout', '_listeners',
4444
'_server_version', '_server_caps', '_intro_query',
45-
'_reset_query', '_proxy', '_stmt_exclusive_section')
45+
'_reset_query', '_proxy', '_stmt_exclusive_section',
46+
'_ssl_context')
4647

4748
def __init__(self, protocol, transport, loop, addr, opts, *,
4849
statement_cache_size, command_timeout,
49-
max_cached_statement_lifetime):
50+
max_cached_statement_lifetime, ssl_context):
5051
self._protocol = protocol
5152
self._transport = transport
5253
self._loop = loop
@@ -58,6 +59,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
5859

5960
self._addr = addr
6061
self._opts = opts
62+
self._ssl_context = ssl_context
6163

6264
self._stmt_cache = _StatementCache(
6365
loop=loop,
@@ -521,12 +523,24 @@ async def cancel():
521523
r, w = await asyncio.open_unix_connection(
522524
self._addr, loop=self._loop)
523525
else:
524-
r, w = await asyncio.open_connection(
525-
*self._addr, loop=self._loop)
526-
527-
sock = w.transport.get_extra_info('socket')
528-
sock.setsockopt(socket.IPPROTO_TCP,
529-
socket.TCP_NODELAY, 1)
526+
if self._ssl_context:
527+
sock = await _get_ssl_ready_socket(
528+
*self._addr, loop=self._loop)
529+
530+
try:
531+
r, w = await asyncio.open_connection(
532+
sock=sock,
533+
loop=self._loop,
534+
ssl=self._ssl_context,
535+
server_hostname=self._addr[0])
536+
except Exception:
537+
sock.close()
538+
raise
539+
540+
else:
541+
r, w = await asyncio.open_connection(
542+
*self._addr, loop=self._loop)
543+
_set_nodelay(_get_socket(w.transport))
530544

531545
# Pack CancelRequest message
532546
msg = struct.pack('!llll', 16, 80877102,
@@ -708,9 +722,10 @@ async def connect(dsn=None, *,
708722
statement_cache_size=100,
709723
max_cached_statement_lifetime=300,
710724
command_timeout=None,
725+
ssl=None,
711726
__connection_class__=Connection,
712727
**opts):
713-
"""A coroutine to establish a connection to a PostgreSQL server.
728+
r"""A coroutine to establish a connection to a PostgreSQL server.
714729
715730
Returns a new :class:`~asyncpg.connection.Connection` object.
716731
@@ -761,6 +776,12 @@ async def connect(dsn=None, *,
761776
the default timeout for operations on this connection
762777
(the default is no timeout).
763778
779+
:param ssl:
780+
pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
781+
require an SSL connection. If ``True``, a default SSL context
782+
returned by `ssl.create_default_context() <create_default_context_>`_
783+
will be used.
784+
764785
:return: A :class:`~asyncpg.connection.Connection` instance.
765786
766787
Example:
@@ -778,42 +799,51 @@ async def connect(dsn=None, *,
778799
779800
.. versionchanged:: 0.10.0
780801
Added ``max_cached_statement_use_count`` parameter.
802+
803+
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
804+
.. _create_default_context: https://docs.python.org/3/library/ssl.html#\
805+
ssl.create_default_context
781806
"""
782807
if loop is None:
783808
loop = asyncio.get_event_loop()
784809

785-
host, port, opts = _parse_connect_params(
810+
addrs, opts = _parse_connect_params(
786811
dsn=dsn, host=host, port=port, user=user, password=password,
787812
database=database, opts=opts)
788813

789-
last_ex = None
814+
if ssl:
815+
for addr in addrs:
816+
if isinstance(addr, str):
817+
# UNIX socket
818+
raise exceptions.InterfaceError(
819+
'`ssl` parameter can only be enabled for TCP addresses, '
820+
'got a UNIX socket path: {!r}'.format(addr))
821+
822+
last_error = None
790823
addr = None
791-
for h in host:
824+
for addr in addrs:
792825
connected = _create_future(loop)
793-
unix = h.startswith('/')
794-
795-
if unix:
796-
# UNIX socket name
797-
addr = h
798-
if '.s.PGSQL.' not in addr:
799-
addr = os.path.join(addr, '.s.PGSQL.{}'.format(port))
800-
conn = loop.create_unix_connection(
801-
lambda: protocol.Protocol(addr, connected, opts, loop),
802-
addr)
826+
proto_factory = lambda: protocol.Protocol(addr, connected, opts, loop)
827+
828+
if isinstance(addr, str):
829+
# UNIX socket
830+
assert ssl is None
831+
connector = loop.create_unix_connection(proto_factory, addr)
832+
elif ssl:
833+
connector = _create_ssl_connection(
834+
proto_factory, *addr, loop=loop, ssl_context=ssl)
803835
else:
804-
addr = (h, port)
805-
conn = loop.create_connection(
806-
lambda: protocol.Protocol(addr, connected, opts, loop),
807-
h, port)
836+
connector = loop.create_connection(proto_factory, *addr)
808837

809838
try:
810-
tr, pr = await asyncio.wait_for(conn, timeout=timeout, loop=loop)
811-
except (OSError, asyncio.TimeoutError) as ex:
812-
last_ex = ex
839+
tr, pr = await asyncio.wait_for(
840+
connector, timeout=timeout, loop=loop)
841+
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
842+
last_error = ex
813843
else:
814844
break
815845
else:
816-
raise last_ex
846+
raise last_error
817847

818848
try:
819849
await connected
@@ -825,12 +855,60 @@ async def connect(dsn=None, *,
825855
pr, tr, loop, addr, opts,
826856
statement_cache_size=statement_cache_size,
827857
max_cached_statement_lifetime=max_cached_statement_lifetime,
828-
command_timeout=command_timeout)
858+
command_timeout=command_timeout, ssl_context=ssl)
829859

830860
pr.set_connection(con)
831861
return con
832862

833863

864+
async def _get_ssl_ready_socket(host, port, *, loop):
865+
reader, writer = await asyncio.open_connection(host, port, loop=loop)
866+
867+
tr = writer.transport
868+
try:
869+
sock = _get_socket(tr)
870+
_set_nodelay(sock)
871+
872+
writer.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
873+
await writer.drain()
874+
resp = await reader.readexactly(1)
875+
876+
if resp == b'S':
877+
return sock.dup()
878+
else:
879+
raise ConnectionError(
880+
'PostgreSQL server at "{}:{}" rejected SSL upgrade'.format(
881+
host, port))
882+
finally:
883+
tr.close()
884+
885+
886+
async def _create_ssl_connection(protocol_factory, host, port, *,
887+
loop, ssl_context):
888+
sock = await _get_ssl_ready_socket(host, port, loop=loop)
889+
try:
890+
return await loop.create_connection(
891+
protocol_factory, sock=sock, ssl=ssl_context,
892+
server_hostname=host)
893+
except Exception:
894+
sock.close()
895+
raise
896+
897+
898+
def _get_socket(transport):
899+
sock = transport.get_extra_info('socket')
900+
if sock is None:
901+
# Shouldn't happen with any asyncio-complaint event loop.
902+
raise ConnectionError(
903+
'could not get the socket for transport {!r}'.format(transport))
904+
return sock
905+
906+
907+
def _set_nodelay(sock):
908+
if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
909+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
910+
911+
834912
class _StatementCacheEntry:
835913

836914
__slots__ = ('_query', '_statement', '_cache', '_cleanup_cb')
@@ -1116,7 +1194,18 @@ def _parse_connect_params(*, dsn, host, port, user,
11161194
'invalid connection parameter {!r}: {!r} (str expected)'
11171195
.format(param, opts[param]))
11181196

1119-
return host, port, opts
1197+
addrs = []
1198+
for h in host:
1199+
if h.startswith('/'):
1200+
# UNIX socket name
1201+
if '.s.PGSQL.' not in h:
1202+
h = os.path.join(h, '.s.PGSQL.{}'.format(port))
1203+
addrs.append(h)
1204+
else:
1205+
# TCP host/port
1206+
addrs.append((h, port))
1207+
1208+
return addrs, opts
11201209

11211210

11221211
def _create_future(loop):

0 commit comments

Comments
 (0)