Skip to content

Commit 1d33ff6

Browse files
authored
Add support for asynchronous iterables to copy_records_to_table() (#713)
The `Connection.copy_records_to_table()` now allows the `records` argument to be an asynchronous iterable. Fixes: #689.
1 parent a6b0f28 commit 1d33ff6

File tree

3 files changed

+87
-24
lines changed

3 files changed

+87
-24
lines changed

asyncpg/connection.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,8 @@ async def copy_records_to_table(self, table_name, *, records,
872872
873873
:param records:
874874
An iterable returning row tuples to copy into the table.
875+
:term:`Asynchronous iterables <python:asynchronous iterable>`
876+
are also supported.
875877
876878
:param list columns:
877879
An optional list of column names to copy.
@@ -901,7 +903,28 @@ async def copy_records_to_table(self, table_name, *, records,
901903
>>> asyncio.get_event_loop().run_until_complete(run())
902904
'COPY 2'
903905
906+
Asynchronous record iterables are also supported:
907+
908+
.. code-block:: pycon
909+
910+
>>> import asyncpg
911+
>>> import asyncio
912+
>>> async def run():
913+
... con = await asyncpg.connect(user='postgres')
914+
... async def record_gen(size):
915+
... for i in range(size):
916+
... yield (i,)
917+
... result = await con.copy_records_to_table(
918+
... 'mytable', records=record_gen(100))
919+
... print(result)
920+
...
921+
>>> asyncio.get_event_loop().run_until_complete(run())
922+
'COPY 100'
923+
904924
.. versionadded:: 0.11.0
925+
926+
.. versionchanged:: 0.24.0
927+
The ``records`` argument may be an asynchronous iterable.
905928
"""
906929
tabname = utils._quote_ident(table_name)
907930
if schema_name:
@@ -924,8 +947,8 @@ async def copy_records_to_table(self, table_name, *, records,
924947
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
925948
tab=tabname, cols=cols, opts=opts)
926949

927-
return await self._copy_in_records(
928-
copy_stmt, records, intro_ps._state, timeout)
950+
return await self._protocol.copy_in(
951+
copy_stmt, None, None, records, intro_ps._state, timeout)
929952

930953
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
931954
delimiter=None, null=None, header=None, quote=None,
@@ -1047,10 +1070,6 @@ async def __anext__(self):
10471070
if opened_by_us:
10481071
await run_in_executor(None, f.close)
10491072

1050-
async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
1051-
return await self._protocol.copy_in(
1052-
copy_stmt, None, None, records, intro_stmt, timeout)
1053-
10541073
async def set_type_codec(self, typename, *,
10551074
schema='public', encoder, decoder,
10561075
format='text'):

asyncpg/protocol/protocol.pyx

+39-18
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ cimport cpython
1313
import asyncio
1414
import builtins
1515
import codecs
16-
import collections
16+
import collections.abc
1717
import socket
1818
import time
1919
import weakref
@@ -438,23 +438,44 @@ cdef class BaseProtocol(CoreProtocol):
438438
'no binary format encoder for '
439439
'type {} (OID {})'.format(codec.name, codec.oid))
440440

441-
for row in records:
442-
# Tuple header
443-
wbuf.write_int16(<int16_t>num_cols)
444-
# Tuple data
445-
for i in range(num_cols):
446-
item = row[i]
447-
if item is None:
448-
wbuf.write_int32(-1)
449-
else:
450-
codec = <Codec>cpython.PyTuple_GET_ITEM(codecs, i)
451-
codec.encode(settings, wbuf, item)
452-
453-
if wbuf.len() >= _COPY_BUFFER_SIZE:
454-
with timer:
455-
await self.writing_allowed.wait()
456-
self._write_copy_data_msg(wbuf)
457-
wbuf = WriteBuffer.new()
441+
if isinstance(records, collections.abc.AsyncIterable):
442+
async for row in records:
443+
# Tuple header
444+
wbuf.write_int16(<int16_t>num_cols)
445+
# Tuple data
446+
for i in range(num_cols):
447+
item = row[i]
448+
if item is None:
449+
wbuf.write_int32(-1)
450+
else:
451+
codec = <Codec>cpython.PyTuple_GET_ITEM(
452+
codecs, i)
453+
codec.encode(settings, wbuf, item)
454+
455+
if wbuf.len() >= _COPY_BUFFER_SIZE:
456+
with timer:
457+
await self.writing_allowed.wait()
458+
self._write_copy_data_msg(wbuf)
459+
wbuf = WriteBuffer.new()
460+
else:
461+
for row in records:
462+
# Tuple header
463+
wbuf.write_int16(<int16_t>num_cols)
464+
# Tuple data
465+
for i in range(num_cols):
466+
item = row[i]
467+
if item is None:
468+
wbuf.write_int32(-1)
469+
else:
470+
codec = <Codec>cpython.PyTuple_GET_ITEM(
471+
codecs, i)
472+
codec.encode(settings, wbuf, item)
473+
474+
if wbuf.len() >= _COPY_BUFFER_SIZE:
475+
with timer:
476+
await self.writing_allowed.wait()
477+
self._write_copy_data_msg(wbuf)
478+
wbuf = WriteBuffer.new()
458479

459480
# End of binary copy.
460481
wbuf.write_int16(-1)

tests/test_copy.py

+23
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,29 @@ async def test_copy_records_to_table_1(self):
644644
finally:
645645
await self.con.execute('DROP TABLE copytab')
646646

647+
async def test_copy_records_to_table_async(self):
648+
await self.con.execute('''
649+
CREATE TABLE copytab_async(a text, b int, c timestamptz);
650+
''')
651+
652+
try:
653+
date = datetime.datetime.now(tz=datetime.timezone.utc)
654+
delta = datetime.timedelta(days=1)
655+
656+
async def record_generator():
657+
for i in range(100):
658+
yield ('a-{}'.format(i), i, date + delta)
659+
660+
yield ('a-100', None, None)
661+
662+
res = await self.con.copy_records_to_table(
663+
'copytab_async', records=record_generator())
664+
665+
self.assertEqual(res, 'COPY 101')
666+
667+
finally:
668+
await self.con.execute('DROP TABLE copytab_async')
669+
647670
async def test_copy_records_to_table_no_binary_codec(self):
648671
await self.con.execute('''
649672
CREATE TABLE copytab(a uuid);

0 commit comments

Comments
 (0)