Skip to content

Commit f26cd38

Browse files
committed
Add support for asynchronous iterables to copy_records_to_table()
The `Connection.copy_records_to_table()` now allows the `records` argument to be an asynchronous iterable. Fixes: #689.
1 parent 53bea98 commit f26cd38

File tree

3 files changed

+90
-24
lines changed

3 files changed

+90
-24
lines changed

asyncpg/connection.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,8 @@ async def copy_records_to_table(self, table_name, *, records,
868868
869869
:param records:
870870
An iterable returning row tuples to copy into the table.
871+
:term:`Asynchronous iterables <python:asynchronous iterable>`
872+
are also supported.
871873
872874
:param list columns:
873875
An optional list of column names to copy.
@@ -897,7 +899,28 @@ async def copy_records_to_table(self, table_name, *, records,
897899
>>> asyncio.get_event_loop().run_until_complete(run())
898900
'COPY 2'
899901
902+
Asynchronous record iterables are also supported:
903+
904+
.. code-block:: pycon
905+
906+
>>> import asyncpg
907+
>>> import asyncio
908+
>>> async def run():
909+
... con = await asyncpg.connect(user='postgres')
910+
... async def record_gen(size):
911+
... for i in range(size):
912+
... yield (i,)
913+
... result = await con.copy_records_to_table(
914+
... 'mytable', records=record_gen(100))
915+
... print(result)
916+
...
917+
>>> asyncio.get_event_loop().run_until_complete(run())
918+
'COPY 100'
919+
900920
.. versionadded:: 0.11.0
921+
922+
.. versionchanged:: 0.23.0
923+
The ``records`` argument may be an asynchronous iterable.
901924
"""
902925
tabname = utils._quote_ident(table_name)
903926
if schema_name:
@@ -920,8 +943,8 @@ async def copy_records_to_table(self, table_name, *, records,
920943
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
921944
tab=tabname, cols=cols, opts=opts)
922945

923-
return await self._copy_in_records(
924-
copy_stmt, records, intro_ps._state, timeout)
946+
return await self._protocol.copy_in(
947+
copy_stmt, None, None, records, intro_ps._state, timeout)
925948

926949
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
927950
delimiter=None, null=None, header=None, quote=None,
@@ -1044,10 +1067,6 @@ async def __anext__(self):
10441067
if opened_by_us:
10451068
await run_in_executor(None, f.close)
10461069

1047-
async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
1048-
return await self._protocol.copy_in(
1049-
copy_stmt, None, None, records, intro_stmt, timeout)
1050-
10511070
async def set_type_codec(self, typename, *,
10521071
schema='public', encoder, decoder,
10531072
format='text'):

asyncpg/protocol/protocol.pyx

+39-18
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ cimport cpython
1313
import asyncio
1414
import builtins
1515
import codecs
16-
import collections
16+
import collections.abc
1717
import socket
1818
import time
1919
import weakref
@@ -436,23 +436,44 @@ cdef class BaseProtocol(CoreProtocol):
436436
'no binary format encoder for '
437437
'type {} (OID {})'.format(codec.name, codec.oid))
438438

439-
for row in records:
440-
# Tuple header
441-
wbuf.write_int16(<int16_t>num_cols)
442-
# Tuple data
443-
for i in range(num_cols):
444-
item = row[i]
445-
if item is None:
446-
wbuf.write_int32(-1)
447-
else:
448-
codec = <Codec>cpython.PyTuple_GET_ITEM(codecs, i)
449-
codec.encode(settings, wbuf, item)
450-
451-
if wbuf.len() >= _COPY_BUFFER_SIZE:
452-
with timer:
453-
await self.writing_allowed.wait()
454-
self._write_copy_data_msg(wbuf)
455-
wbuf = WriteBuffer.new()
439+
if isinstance(records, collections.abc.AsyncIterable):
440+
async for row in records:
441+
# Tuple header
442+
wbuf.write_int16(<int16_t>num_cols)
443+
# Tuple data
444+
for i in range(num_cols):
445+
item = row[i]
446+
if item is None:
447+
wbuf.write_int32(-1)
448+
else:
449+
codec = <Codec>cpython.PyTuple_GET_ITEM(
450+
codecs, i)
451+
codec.encode(settings, wbuf, item)
452+
453+
if wbuf.len() >= _COPY_BUFFER_SIZE:
454+
with timer:
455+
await self.writing_allowed.wait()
456+
self._write_copy_data_msg(wbuf)
457+
wbuf = WriteBuffer.new()
458+
else:
459+
for row in records:
460+
# Tuple header
461+
wbuf.write_int16(<int16_t>num_cols)
462+
# Tuple data
463+
for i in range(num_cols):
464+
item = row[i]
465+
if item is None:
466+
wbuf.write_int32(-1)
467+
else:
468+
codec = <Codec>cpython.PyTuple_GET_ITEM(
469+
codecs, i)
470+
codec.encode(settings, wbuf, item)
471+
472+
if wbuf.len() >= _COPY_BUFFER_SIZE:
473+
with timer:
474+
await self.writing_allowed.wait()
475+
self._write_copy_data_msg(wbuf)
476+
wbuf = WriteBuffer.new()
456477

457478
# End of binary copy.
458479
wbuf.write_int16(-1)

tests/test_copy.py

+26
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import datetime
1010
import io
1111
import os
12+
import sys
1213
import tempfile
14+
import unittest
1315

1416
import asyncpg
1517
from asyncpg import _testbase as tb
@@ -649,6 +651,30 @@ async def test_copy_records_to_table_1(self):
649651
finally:
650652
await self.con.execute('DROP TABLE copytab')
651653

654+
@unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support')
655+
async def test_copy_records_to_table_async(self):
656+
await self.con.execute('''
657+
CREATE TABLE copytab_async(a text, b int, c timestamptz);
658+
''')
659+
660+
try:
661+
date = datetime.datetime.now(tz=datetime.timezone.utc)
662+
delta = datetime.timedelta(days=1)
663+
664+
async def record_generator():
665+
for i in range(100):
666+
yield ('a-{}'.format(i), i, date + delta)
667+
668+
yield ('a-100', None, None)
669+
670+
res = await self.con.copy_records_to_table(
671+
'copytab_async', records=record_generator())
672+
673+
self.assertEqual(res, 'COPY 101')
674+
675+
finally:
676+
await self.con.execute('DROP TABLE copytab_async')
677+
652678
async def test_copy_records_to_table_no_binary_codec(self):
653679
await self.con.execute('''
654680
CREATE TABLE copytab(a uuid);

0 commit comments

Comments
 (0)