Skip to content

Commit 4be5198

Browse files
committed
Rewrite (again) executemany() to batch args for performance.
Now `Bind` and `Execute` pairs are batched into 4 x 32KB buffers to take advantage of `writelines()`. A single `Sync` is sent at last, so that all args live in the same transaction. Closes: #289
1 parent 32fccaa commit 4be5198

File tree

7 files changed

+320
-112
lines changed

7 files changed

+320
-112
lines changed

asyncpg/connection.py

+10
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
301301
302302
.. versionchanged:: 0.11.0
303303
`timeout` became a keyword-only parameter.
304+
305+
.. versionchanged:: 0.19.0
306+
The execution was changed to be in an implicit transaction if there
307+
was no explicit transaction, so that it will no longer end up with
308+
partial success. If you still need the previous behavior to
309+
progressively execute many args, please use a loop with prepared
310+
statement instead.
304311
"""
305312
self._check_open()
306313
return await self._executemany(command, args, timeout)
@@ -821,6 +828,9 @@ async def _copy_in(self, copy_stmt, source, timeout):
821828
f = source
822829
elif isinstance(source, collections.abc.AsyncIterable):
823830
# assuming calling output returns an awaitable.
831+
# copy_in() is designed to handle very large amounts of data, and
832+
# the source async iterable is allowed to return an arbitrary
833+
# amount of data on every iteration.
824834
reader = source
825835
else:
826836
# assuming source is an instance supporting the buffer protocol.

asyncpg/prepared_stmt.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,24 @@ async def fetchrow(self, *args, timeout=None):
196196
return None
197197
return data[0]
198198

199-
async def __bind_execute(self, args, limit, timeout):
199+
@connresource.guarded
200+
async def executemany(self, args, *, timeout: float=None):
201+
"""Execute the statement for each sequence of arguments in *args*.
202+
203+
:param args: An iterable containing sequences of arguments.
204+
:param float timeout: Optional timeout value in seconds.
205+
:return None: This method discards the results of the operations.
206+
207+
.. versionadded:: 0.19.0
208+
"""
209+
return await self.__do_execute(
210+
lambda protocol: protocol.bind_execute_many(
211+
self._state, args, '', timeout))
212+
213+
async def __do_execute(self, executor):
200214
protocol = self._connection._protocol
201215
try:
202-
data, status, _ = await protocol.bind_execute(
203-
self._state, args, '', limit, True, timeout)
216+
return await executor(protocol)
204217
except exceptions.OutdatedSchemaCacheError:
205218
await self._connection.reload_schema_state()
206219
# We can not find all manually created prepared statements, so just
@@ -209,6 +222,11 @@ async def __bind_execute(self, args, limit, timeout):
209222
# invalidate themselves (unfortunately, clearing caches again).
210223
self._state.mark_closed()
211224
raise
225+
226+
async def __bind_execute(self, args, limit, timeout):
227+
data, status, _ = await self.__do_execute(
228+
lambda protocol: protocol.bind_execute(
229+
self._state, args, '', limit, True, timeout))
212230
self._last_status = status
213231
return data
214232

asyncpg/protocol/consts.pxi

+2
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
DEF _MAXINT32 = 2**31 - 1
99
DEF _COPY_BUFFER_SIZE = 524288
1010
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
11+
DEF _EXECUTE_MANY_BUF_NUM = 4
12+
DEF _EXECUTE_MANY_BUF_SIZE = 32768

asyncpg/protocol/coreproto.pxd

+11-7
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,6 @@ cdef class CoreProtocol:
8383
bint _skip_discard
8484
bint _discard_data
8585

86-
# executemany support data
87-
object _execute_iter
88-
str _execute_portal_name
89-
str _execute_stmt_name
90-
9186
ConnectionStatus con_status
9287
ProtocolState state
9388
TransactionStatus xact_status
@@ -114,6 +109,7 @@ cdef class CoreProtocol:
114109
# True - completed, False - suspended
115110
bint result_execute_completed
116111

112+
cpdef is_in_transaction(self)
117113
cdef _process__auth(self, char mtype)
118114
cdef _process__prepare(self, char mtype)
119115
cdef _process__bind_execute(self, char mtype)
@@ -146,6 +142,7 @@ cdef class CoreProtocol:
146142
cdef _auth_password_message_sasl_continue(self, bytes server_response)
147143

148144
cdef _write(self, buf)
145+
cdef _writelines(self, list buffers)
149146

150147
cdef _read_server_messages(self)
151148

@@ -155,9 +152,13 @@ cdef class CoreProtocol:
155152

156153
cdef _ensure_connected(self)
157154

155+
cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
158156
cdef WriteBuffer _build_bind_message(self, str portal_name,
159157
str stmt_name,
160158
WriteBuffer bind_data)
159+
cdef WriteBuffer _build_empty_bind_data(self)
160+
cdef WriteBuffer _build_execute_message(self, str portal_name,
161+
int32_t limit)
161162

162163

163164
cdef _connect(self)
@@ -166,8 +167,11 @@ cdef class CoreProtocol:
166167
WriteBuffer bind_data, int32_t limit)
167168
cdef _bind_execute(self, str portal_name, str stmt_name,
168169
WriteBuffer bind_data, int32_t limit)
169-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
170-
object bind_data)
170+
cdef _execute_many_init(self)
171+
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
172+
object bind_data)
173+
cdef _execute_many_done(self, bint data_sent)
174+
cdef _execute_many_fail(self, object error)
171175
cdef _bind(self, str portal_name, str stmt_name,
172176
WriteBuffer bind_data)
173177
cdef _execute(self, str portal_name, int32_t limit)

asyncpg/protocol/coreproto.pyx

+103-51
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ cdef class CoreProtocol:
2727
# type of `scram` is `SCRAMAuthentcation`
2828
self.scram = None
2929

30-
# executemany support data
31-
self._execute_iter = None
32-
self._execute_portal_name = None
33-
self._execute_stmt_name = None
34-
3530
self._reset_result()
3631

32+
cpdef is_in_transaction(self):
33+
# PQTRANS_INTRANS = idle, within transaction block
34+
# PQTRANS_INERROR = idle, within failed transaction
35+
return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
36+
3737
cdef _read_server_messages(self):
3838
cdef:
3939
char mtype
@@ -258,22 +258,7 @@ cdef class CoreProtocol:
258258
elif mtype == b'Z':
259259
# ReadyForQuery
260260
self._parse_msg_ready_for_query()
261-
if self.result_type == RESULT_FAILED:
262-
self._push_result()
263-
else:
264-
try:
265-
buf = <WriteBuffer>next(self._execute_iter)
266-
except StopIteration:
267-
self._push_result()
268-
except Exception as e:
269-
self.result_type = RESULT_FAILED
270-
self.result = e
271-
self._push_result()
272-
else:
273-
# Next iteration over the executemany() arg sequence
274-
self._send_bind_message(
275-
self._execute_portal_name, self._execute_stmt_name,
276-
buf, 0)
261+
self._push_result()
277262

278263
elif mtype == b'I':
279264
# EmptyQueryResponse
@@ -775,6 +760,17 @@ cdef class CoreProtocol:
775760
if self.con_status != CONNECTION_OK:
776761
raise apg_exc.InternalClientError('not connected')
777762

763+
cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
764+
cdef WriteBuffer buf
765+
766+
buf = WriteBuffer.new_message(b'P')
767+
buf.write_str(stmt_name, self.encoding)
768+
buf.write_str(query, self.encoding)
769+
buf.write_int16(0)
770+
771+
buf.end_message()
772+
return buf
773+
778774
cdef WriteBuffer _build_bind_message(self, str portal_name,
779775
str stmt_name,
780776
WriteBuffer bind_data):
@@ -790,6 +786,25 @@ cdef class CoreProtocol:
790786
buf.end_message()
791787
return buf
792788

789+
cdef WriteBuffer _build_empty_bind_data(self):
790+
cdef WriteBuffer buf
791+
buf = WriteBuffer.new()
792+
buf.write_int16(0) # The number of parameter format codes
793+
buf.write_int16(0) # The number of parameter values
794+
buf.write_int16(0) # The number of result-column format codes
795+
return buf
796+
797+
cdef WriteBuffer _build_execute_message(self, str portal_name,
798+
int32_t limit):
799+
cdef WriteBuffer buf
800+
801+
buf = WriteBuffer.new_message(b'E')
802+
buf.write_str(portal_name, self.encoding) # name of the portal
803+
buf.write_int32(limit) # number of rows to return; 0 - all
804+
805+
buf.end_message()
806+
return buf
807+
793808
# API for subclasses
794809

795810
cdef _connect(self):
@@ -840,12 +855,7 @@ cdef class CoreProtocol:
840855
self._ensure_connected()
841856
self._set_state(PROTOCOL_PREPARE)
842857

843-
buf = WriteBuffer.new_message(b'P')
844-
buf.write_str(stmt_name, self.encoding)
845-
buf.write_str(query, self.encoding)
846-
buf.write_int16(0)
847-
buf.end_message()
848-
packet = buf
858+
packet = self._build_parse_message(stmt_name, query)
849859

850860
buf = WriteBuffer.new_message(b'D')
851861
buf.write_byte(b'S')
@@ -867,10 +877,7 @@ cdef class CoreProtocol:
867877
buf = self._build_bind_message(portal_name, stmt_name, bind_data)
868878
packet = buf
869879

870-
buf = WriteBuffer.new_message(b'E')
871-
buf.write_str(portal_name, self.encoding) # name of the portal
872-
buf.write_int32(limit) # number of rows to return; 0 - all
873-
buf.end_message()
880+
buf = self._build_execute_message(portal_name, limit)
874881
packet.write_buffer(buf)
875882

876883
packet.write_bytes(SYNC_MESSAGE)
@@ -889,30 +896,75 @@ cdef class CoreProtocol:
889896

890897
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
891898

892-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
893-
object bind_data):
894-
895-
cdef WriteBuffer buf
896-
899+
cdef _execute_many_init(self):
897900
self._ensure_connected()
898901
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
899902

900903
self.result = None
901904
self._discard_data = True
902-
self._execute_iter = bind_data
903-
self._execute_portal_name = portal_name
904-
self._execute_stmt_name = stmt_name
905905

906-
try:
907-
buf = <WriteBuffer>next(bind_data)
908-
except StopIteration:
909-
self._push_result()
910-
except Exception as e:
911-
self.result_type = RESULT_FAILED
912-
self.result = e
906+
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
907+
object bind_data):
908+
cdef:
909+
WriteBuffer packet
910+
WriteBuffer buf
911+
list buffers = []
912+
913+
if self.result_type == RESULT_FAILED:
914+
raise StopIteration(True)
915+
916+
while len(buffers) < _EXECUTE_MANY_BUF_NUM:
917+
packet = WriteBuffer.new()
918+
919+
while packet.len() < _EXECUTE_MANY_BUF_SIZE:
920+
try:
921+
buf = <WriteBuffer>next(bind_data)
922+
except StopIteration:
923+
if packet.len() > 0:
924+
buffers.append(packet)
925+
if len(buffers) > 0:
926+
self._writelines(buffers)
927+
raise StopIteration(True)
928+
else:
929+
raise StopIteration(False)
930+
except Exception as ex:
931+
raise StopIteration(ex)
932+
packet.write_buffer(
933+
self._build_bind_message(portal_name, stmt_name, buf))
934+
packet.write_buffer(
935+
self._build_execute_message(portal_name, 0))
936+
buffers.append(packet)
937+
self._writelines(buffers)
938+
939+
cdef _execute_many_done(self, bint data_sent):
940+
if data_sent:
941+
self._write(SYNC_MESSAGE)
942+
else:
913943
self._push_result()
944+
945+
cdef _execute_many_fail(self, object error):
946+
cdef WriteBuffer buf
947+
948+
self.result_type = RESULT_FAILED
949+
self.result = error
950+
951+
# We shall rollback in an implicit transaction to prevent partial
952+
# commit, while do nothing in an explicit transaction and leaving the
953+
# error to the user
954+
if self.is_in_transaction():
955+
self._execute_many_done(True)
914956
else:
915-
self._send_bind_message(portal_name, stmt_name, buf, 0)
957+
# Here if the implicit transaction is in `ignore_till_sync` mode,
958+
# the `ROLLBACK` will be ignored and `Sync` will restore the state;
959+
# or else the implicit transaction will be rolled back with a
960+
# warning saying that there was no transaction, but rollback is
961+
# done anyway, so we could ignore this warning.
962+
buf = self._build_parse_message('', 'ROLLBACK')
963+
buf.write_buffer(self._build_bind_message(
964+
'', '', self._build_empty_bind_data()))
965+
buf.write_buffer(self._build_execute_message('', 0))
966+
buf.write_bytes(SYNC_MESSAGE)
967+
self._write(buf)
916968

917969
cdef _execute(self, str portal_name, int32_t limit):
918970
cdef WriteBuffer buf
@@ -922,10 +974,7 @@ cdef class CoreProtocol:
922974

923975
self.result = []
924976

925-
buf = WriteBuffer.new_message(b'E')
926-
buf.write_str(portal_name, self.encoding) # name of the portal
927-
buf.write_int32(limit) # number of rows to return; 0 - all
928-
buf.end_message()
977+
buf = self._build_execute_message(portal_name, limit)
929978

930979
buf.write_bytes(SYNC_MESSAGE)
931980

@@ -1008,6 +1057,9 @@ cdef class CoreProtocol:
10081057
cdef _write(self, buf):
10091058
raise NotImplementedError
10101059

1060+
cdef _writelines(self, list buffers):
1061+
raise NotImplementedError
1062+
10111063
cdef _decode_row(self, const char* buf, ssize_t buf_len):
10121064
pass
10131065

0 commit comments

Comments
 (0)