Skip to content

Commit 4b4407c

Browse files
authored
Merge pull request #4 from MagicStack/master
Sync Fork from Upstream Repo
2 parents 888f68e + 7f5c2a2 commit 4b4407c

File tree

5 files changed

+102
-3
lines changed

5 files changed

+102
-3
lines changed

asyncpg/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@
3131
# snapshots will automatically include the git revision
3232
# in __version__, for example: '0.16.0.dev0+ge06ad03'
3333

34-
__version__ = '0.20.1'
34+
__version__ = '0.21.0.dev0'

asyncpg/connect_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import typing
2222
import urllib.parse
2323
import warnings
24+
import inspect
2425

2526
from . import compat
2627
from . import exceptions
@@ -601,6 +602,16 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
601602
raise asyncio.TimeoutError
602603

603604
connected = _create_future(loop)
605+
606+
params_input = params
607+
if callable(params.password):
608+
if inspect.iscoroutinefunction(params.password):
609+
password = await params.password()
610+
else:
611+
password = params.password()
612+
613+
params = params._replace(password=password)
614+
604615
proto_factory = lambda: protocol.Protocol(
605616
addr, connected, params, loop)
606617

@@ -633,7 +644,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
633644
tr.close()
634645
raise
635646

636-
con = connection_class(pr, tr, loop, addr, config, params)
647+
con = connection_class(pr, tr, loop, addr, config, params_input)
637648
pr.set_connection(con)
638649
return con
639650

asyncpg/connection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ async def _copy_in(self, copy_stmt, source, timeout):
813813

814814
if path is not None:
815815
# a path
816-
f = await run_in_executor(None, open, path, 'wb')
816+
f = await run_in_executor(None, open, path, 'rb')
817817
opened_by_us = True
818818
elif hasattr(source, 'read'):
819819
# file-like
@@ -1566,6 +1566,10 @@ async def connect(dsn=None, *,
15661566
other users and applications may be able to read it without needing
15671567
specific privileges. It is recommended to use *passfile* instead.
15681568
1569+
Password may be either a string, or a callable that returns a string.
1570+
If a callable is provided, it will be called each time a new connection
1571+
is established.
1572+
15691573
:param passfile:
15701574
The name of the file used to store passwords
15711575
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
@@ -1646,6 +1650,9 @@ async def connect(dsn=None, *,
16461650
Added ability to specify multiple hosts in the *dsn*
16471651
and *host* arguments.
16481652
1653+
.. versionchanged:: 0.21.0
1654+
The *password* argument now accepts a callable or an async function.
1655+
16491656
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
16501657
.. _create_default_context:
16511658
https://docs.python.org/3/library/ssl.html#ssl.create_default_context

tests/test_connect.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,44 @@ async def test_auth_password_cleartext(self):
204204
user='password_user',
205205
password='wrongpassword')
206206

207+
async def test_auth_password_cleartext_callable(self):
208+
def get_correctpassword():
209+
return 'correctpassword'
210+
211+
def get_wrongpassword():
212+
return 'wrongpassword'
213+
214+
conn = await self.connect(
215+
user='password_user',
216+
password=get_correctpassword)
217+
await conn.close()
218+
219+
with self.assertRaisesRegex(
220+
asyncpg.InvalidPasswordError,
221+
'password authentication failed for user "password_user"'):
222+
await self._try_connect(
223+
user='password_user',
224+
password=get_wrongpassword)
225+
226+
async def test_auth_password_cleartext_callable_coroutine(self):
227+
async def get_correctpassword():
228+
return 'correctpassword'
229+
230+
async def get_wrongpassword():
231+
return 'wrongpassword'
232+
233+
conn = await self.connect(
234+
user='password_user',
235+
password=get_correctpassword)
236+
await conn.close()
237+
238+
with self.assertRaisesRegex(
239+
asyncpg.InvalidPasswordError,
240+
'password authentication failed for user "password_user"'):
241+
await self._try_connect(
242+
user='password_user',
243+
password=get_wrongpassword)
244+
207245
async def test_auth_password_md5(self):
208246
conn = await self.connect(
209247
user='md5_user', password='correctpassword')

tests/test_copy.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import datetime
1010
import io
11+
import os
1112
import tempfile
1213

1314
import asyncpg
@@ -582,6 +583,48 @@ async def __anext__(self):
582583
finally:
583584
await self.con.execute('DROP TABLE copytab')
584585

586+
async def test_copy_to_table_from_file_path(self):
587+
await self.con.execute('''
588+
CREATE TABLE copytab(a text, "b~" text, i int);
589+
''')
590+
591+
f = tempfile.NamedTemporaryFile(delete=False)
592+
try:
593+
f.write(
594+
'\n'.join([
595+
'a1\tb1\t1',
596+
'a2\tb2\t2',
597+
'a3\tb3\t3',
598+
'a4\tb4\t4',
599+
'a5\tb5\t5',
600+
'*\t\\N\t\\N',
601+
''
602+
]).encode('utf-8')
603+
)
604+
f.close()
605+
606+
res = await self.con.copy_to_table('copytab', source=f.name)
607+
self.assertEqual(res, 'COPY 6')
608+
609+
output = await self.con.fetch("""
610+
SELECT * FROM copytab ORDER BY a
611+
""")
612+
self.assertEqual(
613+
output,
614+
[
615+
('*', None, None),
616+
('a1', 'b1', 1),
617+
('a2', 'b2', 2),
618+
('a3', 'b3', 3),
619+
('a4', 'b4', 4),
620+
('a5', 'b5', 5),
621+
]
622+
)
623+
624+
finally:
625+
await self.con.execute('DROP TABLE public.copytab')
626+
os.unlink(f.name)
627+
585628
async def test_copy_records_to_table_1(self):
586629
await self.con.execute('''
587630
CREATE TABLE copytab(a text, b int, c timestamptz);

0 commit comments

Comments
 (0)