Skip to content

Commit 9f6839b

Browse files
committed
Add support for asynchronous iterables to copy_records_to_table()
The `Connection.copy_records_to_table()` now allows the `records` argument to be an asynchronous iterable. Fixes: #689.
1 parent a308a97 commit 9f6839b

File tree

3 files changed

+87
-23
lines changed

3 files changed

+87
-23
lines changed

asyncpg/connection.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,8 @@ async def copy_records_to_table(self, table_name, *, records,
868868
869869
:param records:
870870
An iterable returning row tuples to copy into the table.
871+
:term:`Asynchronous iterables <python:asynchronous iterable>`
872+
are also supported.
871873
872874
:param list columns:
873875
An optional list of column names to copy.
@@ -897,7 +899,28 @@ async def copy_records_to_table(self, table_name, *, records,
897899
>>> asyncio.get_event_loop().run_until_complete(run())
898900
'COPY 2'
899901
902+
Asynchronous record iterables are also supported:
903+
904+
.. code-block:: pycon
905+
906+
>>> import asyncpg
907+
>>> import asyncio
908+
>>> async def run():
909+
... con = await asyncpg.connect(user='postgres')
910+
... async def record_gen(size):
911+
... for i in range(size):
912+
... yield (i,)
913+
... result = await con.copy_records_to_table(
914+
... 'mytable', records=record_gen(100))
915+
... print(result)
916+
...
917+
>>> asyncio.get_event_loop().run_until_complete(run())
918+
'COPY 100'
919+
900920
.. versionadded:: 0.11.0
921+
922+
.. versionchanged:: 0.23.0
923+
The ``records`` argument may be an asynchronous iterable.
901924
"""
902925
tabname = utils._quote_ident(table_name)
903926
if schema_name:
@@ -920,8 +943,8 @@ async def copy_records_to_table(self, table_name, *, records,
920943
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
921944
tab=tabname, cols=cols, opts=opts)
922945

923-
return await self._copy_in_records(
924-
copy_stmt, records, intro_ps._state, timeout)
946+
return await self._protocol.copy_in(
947+
copy_stmt, None, None, records, intro_ps._state, timeout)
925948

926949
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
927950
delimiter=None, null=None, header=None, quote=None,
@@ -1044,10 +1067,6 @@ async def __anext__(self):
10441067
if opened_by_us:
10451068
await run_in_executor(None, f.close)
10461069

1047-
async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
1048-
return await self._protocol.copy_in(
1049-
copy_stmt, None, None, records, intro_stmt, timeout)
1050-
10511070
async def set_type_codec(self, typename, *,
10521071
schema='public', encoder, decoder,
10531072
format='text'):

asyncpg/protocol/protocol.pyx

+39-17
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import asyncio
1414
import builtins
1515
import codecs
1616
import collections
17+
import collections.abc
1718
import socket
1819
import time
1920
import weakref
@@ -436,23 +437,44 @@ cdef class BaseProtocol(CoreProtocol):
436437
'no binary format encoder for '
437438
'type {} (OID {})'.format(codec.name, codec.oid))
438439

439-
for row in records:
440-
# Tuple header
441-
wbuf.write_int16(<int16_t>num_cols)
442-
# Tuple data
443-
for i in range(num_cols):
444-
item = row[i]
445-
if item is None:
446-
wbuf.write_int32(-1)
447-
else:
448-
codec = <Codec>cpython.PyTuple_GET_ITEM(codecs, i)
449-
codec.encode(settings, wbuf, item)
450-
451-
if wbuf.len() >= _COPY_BUFFER_SIZE:
452-
with timer:
453-
await self.writing_allowed.wait()
454-
self._write_copy_data_msg(wbuf)
455-
wbuf = WriteBuffer.new()
440+
if isinstance(records, collections.abc.AsyncIterable):
441+
async for row in records:
442+
# Tuple header
443+
wbuf.write_int16(<int16_t>num_cols)
444+
# Tuple data
445+
for i in range(num_cols):
446+
item = row[i]
447+
if item is None:
448+
wbuf.write_int32(-1)
449+
else:
450+
codec = <Codec>cpython.PyTuple_GET_ITEM(
451+
codecs, i)
452+
codec.encode(settings, wbuf, item)
453+
454+
if wbuf.len() >= _COPY_BUFFER_SIZE:
455+
with timer:
456+
await self.writing_allowed.wait()
457+
self._write_copy_data_msg(wbuf)
458+
wbuf = WriteBuffer.new()
459+
else:
460+
for row in records:
461+
# Tuple header
462+
wbuf.write_int16(<int16_t>num_cols)
463+
# Tuple data
464+
for i in range(num_cols):
465+
item = row[i]
466+
if item is None:
467+
wbuf.write_int32(-1)
468+
else:
469+
codec = <Codec>cpython.PyTuple_GET_ITEM(
470+
codecs, i)
471+
codec.encode(settings, wbuf, item)
472+
473+
if wbuf.len() >= _COPY_BUFFER_SIZE:
474+
with timer:
475+
await self.writing_allowed.wait()
476+
self._write_copy_data_msg(wbuf)
477+
wbuf = WriteBuffer.new()
456478

457479
# End of binary copy.
458480
wbuf.write_int16(-1)

tests/test_copy.py

+23
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,29 @@ async def test_copy_records_to_table_1(self):
649649
finally:
650650
await self.con.execute('DROP TABLE copytab')
651651

652+
async def test_copy_records_to_table_async(self):
653+
await self.con.execute('''
654+
CREATE TABLE copytab_async(a text, b int, c timestamptz);
655+
''')
656+
657+
try:
658+
date = datetime.datetime.now(tz=datetime.timezone.utc)
659+
delta = datetime.timedelta(days=1)
660+
661+
async def record_generator():
662+
for i in range(100):
663+
yield ('a-{}'.format(i), i, date + delta)
664+
665+
yield ('a-100', None, None)
666+
667+
res = await self.con.copy_records_to_table(
668+
'copytab_async', records=record_generator())
669+
670+
self.assertEqual(res, 'COPY 101')
671+
672+
finally:
673+
await self.con.execute('DROP TABLE copytab_async')
674+
652675
async def test_copy_records_to_table_no_binary_codec(self):
653676
await self.con.execute('''
654677
CREATE TABLE copytab(a uuid);

0 commit comments

Comments
 (0)