Skip to content

[RFC] Combine args of executemany() in batches #289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,15 @@ async def executemany(self, command: str, args, *, timeout: float=None):

.. versionchanged:: 0.11.0
`timeout` became a keyword-only parameter.

.. versionchanged:: 0.16.0
The execution now runs in an implicit transaction if there is no
explicit transaction, so that it will no longer end up with
partial success. It also combines all args into one network packet
to reduce round-trip time, therefore you should make sure not to
exhaust your memory with a very long iterable. If you still need
the previous behavior of progressively executing many args, please
use a prepared statement instead.
"""
self._check_open()
return await self._executemany(command, args, timeout)
Expand Down
22 changes: 22 additions & 0 deletions asyncpg/prepared_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,28 @@ async def fetchrow(self, *args, timeout=None):
return None
return data[0]

@connresource.guarded
async def executemany(self, args, *, timeout: float=None):
    """Run the prepared statement once per argument sequence in *args*.

    All argument sets are combined into a single network packet, which
    needs fewer round trips than executing them one by one.

    :param args: An iterable containing sequences of arguments.
    :param float timeout: Optional timeout value in seconds.
    :return None: This method discards the results of the operations.

    .. versionadded:: 0.16.0
    """
    proto = self._connection._protocol
    try:
        outcome = await proto.bind_execute_many(
            self._state, args, '', timeout)
    except exceptions.OutdatedSchemaCacheError:
        # Reload the connection's schema state and mark this statement's
        # state as closed before letting the error reach the caller.
        await self._connection.reload_schema_state()
        self._state.mark_closed()
        raise
    return outcome

async def __bind_execute(self, args, limit, timeout):
protocol = self._connection._protocol
try:
Expand Down
5 changes: 0 additions & 5 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ cdef class CoreProtocol:
bint _skip_discard
bint _discard_data

# executemany support data
object _execute_iter
str _execute_portal_name
str _execute_stmt_name

ConnectionStatus con_status
ProtocolState state
TransactionStatus xact_status
Expand Down
65 changes: 30 additions & 35 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@ cdef class CoreProtocol:

self._skip_discard = False

# executemany support data
self._execute_iter = None
self._execute_portal_name = None
self._execute_stmt_name = None

self._reset_result()

cdef _write(self, buf):
Expand Down Expand Up @@ -256,22 +251,7 @@ cdef class CoreProtocol:
elif mtype == b'Z':
# ReadyForQuery
self._parse_msg_ready_for_query()
if self.result_type == RESULT_FAILED:
self._push_result()
else:
try:
buf = <WriteBuffer>next(self._execute_iter)
except StopIteration:
self._push_result()
except Exception as e:
self.result_type = RESULT_FAILED
self.result = e
self._push_result()
else:
# Next iteration over the executemany() arg sequence
self._send_bind_message(
self._execute_portal_name, self._execute_stmt_name,
buf, 0)
self._push_result()

elif mtype == b'I':
# EmptyQueryResponse
Expand Down Expand Up @@ -799,27 +779,42 @@ cdef class CoreProtocol:
cdef _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data):
# NOTE(review): this span appears to be a rendered diff in which the removed
# and added lines are interleaved and the +/- markers and indentation were
# lost. The two `WriteBuffer buf` declarations, and the `try` block followed
# by the `while True` loop, are presumably the old and new sides of the same
# change -- confirm against the real coreproto.pyx before relying on this.
#
# Visible behaviour of the `while True` variant: every WriteBuffer pulled
# from bind_data is turned into a message by _build_bind_message and
# followed by a b'E' message for the same portal; both are appended to
# `packet`, which is written to the transport once bind_data is exhausted.

cdef WriteBuffer buf
cdef:
WriteBuffer packet
WriteBuffer buf

self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

# Accumulates the messages for all argument sets.
packet = WriteBuffer.new()

self.result = None
self._discard_data = True
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name

# Single-step variant: only the first item of bind_data is fetched here and
# sent through _send_bind_message.
try:
buf = <WriteBuffer>next(bind_data)
except StopIteration:
self._push_result()
except Exception as e:
self.result_type = RESULT_FAILED
self.result = e
self._push_result()
else:
self._send_bind_message(portal_name, stmt_name, buf, 0)
# Batched variant: drain bind_data completely before writing anything.
while True:
try:
buf = <WriteBuffer>next(bind_data)
except StopIteration:
# Iterator exhausted. If anything was queued, terminate the batch with
# SYNC_MESSAGE and send it in a single transport write; otherwise
# (empty bind_data) push the result straight away.
if packet.len() > 0:
packet.write_bytes(SYNC_MESSAGE)
self.transport.write(memoryview(packet))
else:
self._push_result()
break
except Exception as e:
# Any other error raised while advancing bind_data is recorded as a
# failed result; the messages buffered so far are not written.
self.result_type = RESULT_FAILED
self.result = e
self._push_result()
break
else:
buf = self._build_bind_message(portal_name, stmt_name, buf)
packet.write_buffer(buf)

# Message of type b'E' carrying the portal name and a row limit of 0.
buf = WriteBuffer.new_message(b'E')
buf.write_str(portal_name, self.encoding) # name of the portal
buf.write_int32(0) # number of rows to return; 0 - all
buf.end_message()
packet.write_buffer(buf)

cdef _execute(self, str portal_name, int32_t limit):
cdef WriteBuffer buf
Expand Down
22 changes: 22 additions & 0 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,25 @@ async def test_execute_many_2(self):
''', good_data)
finally:
await self.con.execute('DROP TABLE exmany')

async def test_execute_many_atomic(self):
    """A failing executemany() must leave the target table empty."""
    from asyncpg.exceptions import UniqueViolationError

    create_sql = 'CREATE TEMP TABLE exmany (a text, b int PRIMARY KEY)'
    insert_sql = '\nINSERT INTO exmany VALUES($1, $2)\n'
    select_sql = '\nSELECT * FROM exmany\n'

    await self.con.execute(create_sql)

    # ('b', 2) and ('c', 2) collide on the primary key column.
    batch = [('a', 1), ('b', 2), ('c', 2), ('d', 4)]

    try:
        with self.assertRaises(UniqueViolationError):
            await self.con.executemany(insert_sql, batch)

        rows = await self.con.fetch(select_sql)
        self.assertEqual(rows, [])
    finally:
        await self.con.execute('DROP TABLE exmany')
35 changes: 35 additions & 0 deletions tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,38 @@ async def test_prepare_does_not_use_cache(self):
# prepare with disabled cache
await self.con.prepare('select 1')
self.assertEqual(len(cache), 0)

async def test_prepare_executemany(self):
    """Prepared-statement executemany() inserts all rows and returns None."""
    insert_sql = '\nINSERT INTO exmany VALUES($1, $2)\n'
    select_sql = '\nSELECT * FROM exmany\n'
    expected = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    await self.con.execute('CREATE TEMP TABLE exmany (a text, b int)')

    try:
        stmt = await self.con.prepare(insert_sql)

        # Hand over a copy so the expected rows stay independent of the call.
        outcome = await stmt.executemany(list(expected))
        self.assertIsNone(outcome)

        rows = await self.con.fetch(select_sql)
        self.assertEqual(rows, expected)

        # Empty set
        await stmt.executemany(())

        rows = await self.con.fetch(select_sql)
        self.assertEqual(rows, expected)
    finally:
        await self.con.execute('DROP TABLE exmany')