Skip to content

Commit 1d9457f

Browse files
Harvey Fryeelprans
Harvey Frye
authored andcommitted
Add support for password functions (useful for RDS IAM auth) (#554)
Closes: #554 Closes: #553
1 parent 1d4325c commit 1d9457f

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

asyncpg/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@
3131
# snapshots will automatically include the git revision
3232
# in __version__, for example: '0.16.0.dev0+ge06ad03'
3333

34-
__version__ = '0.20.1'
34+
__version__ = '0.21.0.dev0'

asyncpg/connect_utils.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import typing
2222
import urllib.parse
2323
import warnings
24+
import inspect
2425

2526
from . import compat
2627
from . import exceptions
@@ -601,6 +602,16 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
601602
raise asyncio.TimeoutError
602603

603604
connected = _create_future(loop)
605+
606+
params_input = params
607+
if callable(params.password):
608+
if inspect.iscoroutinefunction(params.password):
609+
password = await params.password()
610+
else:
611+
password = params.password()
612+
613+
params = params._replace(password=password)
614+
604615
proto_factory = lambda: protocol.Protocol(
605616
addr, connected, params, loop)
606617

@@ -633,7 +644,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
633644
tr.close()
634645
raise
635646

636-
con = connection_class(pr, tr, loop, addr, config, params)
647+
con = connection_class(pr, tr, loop, addr, config, params_input)
637648
pr.set_connection(con)
638649
return con
639650

asyncpg/connection.py

+7
Original file line numberDiff line numberDiff line change
@@ -1566,6 +1566,10 @@ async def connect(dsn=None, *,
15661566
other users and applications may be able to read it without needing
15671567
specific privileges. It is recommended to use *passfile* instead.
15681568
1569+
Password may be either a string, or a callable that returns a string.
1570+
If a callable is provided, it will be called each time a new connection
1571+
is established.
1572+
15691573
:param passfile:
15701574
The name of the file used to store passwords
15711575
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
@@ -1646,6 +1650,9 @@ async def connect(dsn=None, *,
16461650
Added ability to specify multiple hosts in the *dsn*
16471651
and *host* arguments.
16481652
1653+
.. versionchanged:: 0.21.0
1654+
The *password* argument now accepts a callable or an async function.
1655+
16491656
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
16501657
.. _create_default_context:
16511658
https://docs.python.org/3/library/ssl.html#ssl.create_default_context

tests/test_connect.py

+38
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,44 @@ async def test_auth_password_cleartext(self):
204204
user='password_user',
205205
password='wrongpassword')
206206

207+
async def test_auth_password_cleartext_callable(self):
208+
def get_correctpassword():
209+
return 'correctpassword'
210+
211+
def get_wrongpassword():
212+
return 'wrongpassword'
213+
214+
conn = await self.connect(
215+
user='password_user',
216+
password=get_correctpassword)
217+
await conn.close()
218+
219+
with self.assertRaisesRegex(
220+
asyncpg.InvalidPasswordError,
221+
'password authentication failed for user "password_user"'):
222+
await self._try_connect(
223+
user='password_user',
224+
password=get_wrongpassword)
225+
226+
async def test_auth_password_cleartext_callable_coroutine(self):
227+
async def get_correctpassword():
228+
return 'correctpassword'
229+
230+
async def get_wrongpassword():
231+
return 'wrongpassword'
232+
233+
conn = await self.connect(
234+
user='password_user',
235+
password=get_correctpassword)
236+
await conn.close()
237+
238+
with self.assertRaisesRegex(
239+
asyncpg.InvalidPasswordError,
240+
'password authentication failed for user "password_user"'):
241+
await self._try_connect(
242+
user='password_user',
243+
password=get_wrongpassword)
244+
207245
async def test_auth_password_md5(self):
208246
conn = await self.connect(
209247
user='md5_user', password='correctpassword')

0 commit comments

Comments
 (0)