Skip to content

Commit a3bfcff

Browse files
committed
Rewrite (again) executemany() to batch args for performance.
Now `Bind` and `Execute` pairs are batched into 4 x 32KB buffers to take advantage of `writelines()`. A single `Sync` is sent at last, so that all args live in the same transaction. Closes: #289
1 parent 92aa806 commit a3bfcff

File tree

7 files changed

+320
-112
lines changed

7 files changed

+320
-112
lines changed

asyncpg/connection.py

+10
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
331331
332332
.. versionchanged:: 0.11.0
333333
`timeout` became a keyword-only parameter.
334+
335+
.. versionchanged:: 0.19.0
336+
The execution was changed to be in an implicit transaction if there
337+
was no explicit transaction, so that it will no longer end up with
338+
partial success. If you still need the previous behavior to
339+
progressively execute many args, please use a loop with prepared
340+
statement instead.
334341
"""
335342
self._check_open()
336343
return await self._executemany(command, args, timeout)
@@ -1010,6 +1017,9 @@ async def _copy_in(self, copy_stmt, source, timeout):
10101017
f = source
10111018
elif isinstance(source, collections.abc.AsyncIterable):
10121019
# assuming calling output returns an awaitable.
1020+
# copy_in() is designed to handle very large amounts of data, and
1021+
# the source async iterable is allowed to return an arbitrary
1022+
# amount of data on every iteration.
10131023
reader = source
10141024
else:
10151025
# assuming source is an instance supporting the buffer protocol.

asyncpg/prepared_stmt.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,24 @@ async def fetchrow(self, *args, timeout=None):
202202
return None
203203
return data[0]
204204

205-
async def __bind_execute(self, args, limit, timeout):
205+
@connresource.guarded
206+
async def executemany(self, args, *, timeout: float=None):
207+
"""Execute the statement for each sequence of arguments in *args*.
208+
209+
:param args: An iterable containing sequences of arguments.
210+
:param float timeout: Optional timeout value in seconds.
211+
:return None: This method discards the results of the operations.
212+
213+
.. versionadded:: 0.19.0
214+
"""
215+
return await self.__do_execute(
216+
lambda protocol: protocol.bind_execute_many(
217+
self._state, args, '', timeout))
218+
219+
async def __do_execute(self, executor):
206220
protocol = self._connection._protocol
207221
try:
208-
data, status, _ = await protocol.bind_execute(
209-
self._state, args, '', limit, True, timeout)
222+
return await executor(protocol)
210223
except exceptions.OutdatedSchemaCacheError:
211224
await self._connection.reload_schema_state()
212225
# We can not find all manually created prepared statements, so just
@@ -215,6 +228,11 @@ async def __bind_execute(self, args, limit, timeout):
215228
# invalidate themselves (unfortunately, clearing caches again).
216229
self._state.mark_closed()
217230
raise
231+
232+
async def __bind_execute(self, args, limit, timeout):
233+
data, status, _ = await self.__do_execute(
234+
lambda protocol: protocol.bind_execute(
235+
self._state, args, '', limit, True, timeout))
218236
self._last_status = status
219237
return data
220238

asyncpg/protocol/consts.pxi

+2
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
DEF _MAXINT32 = 2**31 - 1
99
DEF _COPY_BUFFER_SIZE = 524288
1010
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
11+
DEF _EXECUTE_MANY_BUF_NUM = 4
12+
DEF _EXECUTE_MANY_BUF_SIZE = 32768

asyncpg/protocol/coreproto.pxd

+11-7
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,6 @@ cdef class CoreProtocol:
8383
bint _skip_discard
8484
bint _discard_data
8585

86-
# executemany support data
87-
object _execute_iter
88-
str _execute_portal_name
89-
str _execute_stmt_name
90-
9186
ConnectionStatus con_status
9287
ProtocolState state
9388
TransactionStatus xact_status
@@ -114,6 +109,7 @@ cdef class CoreProtocol:
114109
# True - completed, False - suspended
115110
bint result_execute_completed
116111

112+
cpdef is_in_transaction(self)
117113
cdef _process__auth(self, char mtype)
118114
cdef _process__prepare(self, char mtype)
119115
cdef _process__bind_execute(self, char mtype)
@@ -146,6 +142,7 @@ cdef class CoreProtocol:
146142
cdef _auth_password_message_sasl_continue(self, bytes server_response)
147143

148144
cdef _write(self, buf)
145+
cdef _writelines(self, list buffers)
149146

150147
cdef _read_server_messages(self)
151148

@@ -155,9 +152,13 @@ cdef class CoreProtocol:
155152

156153
cdef _ensure_connected(self)
157154

155+
cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
158156
cdef WriteBuffer _build_bind_message(self, str portal_name,
159157
str stmt_name,
160158
WriteBuffer bind_data)
159+
cdef WriteBuffer _build_empty_bind_data(self)
160+
cdef WriteBuffer _build_execute_message(self, str portal_name,
161+
int32_t limit)
161162

162163

163164
cdef _connect(self)
@@ -166,8 +167,11 @@ cdef class CoreProtocol:
166167
WriteBuffer bind_data, int32_t limit)
167168
cdef _bind_execute(self, str portal_name, str stmt_name,
168169
WriteBuffer bind_data, int32_t limit)
169-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
170-
object bind_data)
170+
cdef _execute_many_init(self)
171+
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
172+
object bind_data)
173+
cdef _execute_many_done(self, bint data_sent)
174+
cdef _execute_many_fail(self, object error)
171175
cdef _bind(self, str portal_name, str stmt_name,
172176
WriteBuffer bind_data)
173177
cdef _execute(self, str portal_name, int32_t limit)

asyncpg/protocol/coreproto.pyx

+103-51
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ cdef class CoreProtocol:
2727
# type of `scram` is `SCRAMAuthentcation`
2828
self.scram = None
2929

30-
# executemany support data
31-
self._execute_iter = None
32-
self._execute_portal_name = None
33-
self._execute_stmt_name = None
34-
3530
self._reset_result()
3631

32+
cpdef is_in_transaction(self):
33+
# PQTRANS_INTRANS = idle, within transaction block
34+
# PQTRANS_INERROR = idle, within failed transaction
35+
return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
36+
3737
cdef _read_server_messages(self):
3838
cdef:
3939
char mtype
@@ -263,22 +263,7 @@ cdef class CoreProtocol:
263263
elif mtype == b'Z':
264264
# ReadyForQuery
265265
self._parse_msg_ready_for_query()
266-
if self.result_type == RESULT_FAILED:
267-
self._push_result()
268-
else:
269-
try:
270-
buf = <WriteBuffer>next(self._execute_iter)
271-
except StopIteration:
272-
self._push_result()
273-
except Exception as e:
274-
self.result_type = RESULT_FAILED
275-
self.result = e
276-
self._push_result()
277-
else:
278-
# Next iteration over the executemany() arg sequence
279-
self._send_bind_message(
280-
self._execute_portal_name, self._execute_stmt_name,
281-
buf, 0)
266+
self._push_result()
282267

283268
elif mtype == b'I':
284269
# EmptyQueryResponse
@@ -780,6 +765,17 @@ cdef class CoreProtocol:
780765
if self.con_status != CONNECTION_OK:
781766
raise apg_exc.InternalClientError('not connected')
782767

768+
cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
769+
cdef WriteBuffer buf
770+
771+
buf = WriteBuffer.new_message(b'P')
772+
buf.write_str(stmt_name, self.encoding)
773+
buf.write_str(query, self.encoding)
774+
buf.write_int16(0)
775+
776+
buf.end_message()
777+
return buf
778+
783779
cdef WriteBuffer _build_bind_message(self, str portal_name,
784780
str stmt_name,
785781
WriteBuffer bind_data):
@@ -795,6 +791,25 @@ cdef class CoreProtocol:
795791
buf.end_message()
796792
return buf
797793

794+
cdef WriteBuffer _build_empty_bind_data(self):
795+
cdef WriteBuffer buf
796+
buf = WriteBuffer.new()
797+
buf.write_int16(0) # The number of parameter format codes
798+
buf.write_int16(0) # The number of parameter values
799+
buf.write_int16(0) # The number of result-column format codes
800+
return buf
801+
802+
cdef WriteBuffer _build_execute_message(self, str portal_name,
803+
int32_t limit):
804+
cdef WriteBuffer buf
805+
806+
buf = WriteBuffer.new_message(b'E')
807+
buf.write_str(portal_name, self.encoding) # name of the portal
808+
buf.write_int32(limit) # number of rows to return; 0 - all
809+
810+
buf.end_message()
811+
return buf
812+
798813
# API for subclasses
799814

800815
cdef _connect(self):
@@ -845,12 +860,7 @@ cdef class CoreProtocol:
845860
self._ensure_connected()
846861
self._set_state(PROTOCOL_PREPARE)
847862

848-
buf = WriteBuffer.new_message(b'P')
849-
buf.write_str(stmt_name, self.encoding)
850-
buf.write_str(query, self.encoding)
851-
buf.write_int16(0)
852-
buf.end_message()
853-
packet = buf
863+
packet = self._build_parse_message(stmt_name, query)
854864

855865
buf = WriteBuffer.new_message(b'D')
856866
buf.write_byte(b'S')
@@ -872,10 +882,7 @@ cdef class CoreProtocol:
872882
buf = self._build_bind_message(portal_name, stmt_name, bind_data)
873883
packet = buf
874884

875-
buf = WriteBuffer.new_message(b'E')
876-
buf.write_str(portal_name, self.encoding) # name of the portal
877-
buf.write_int32(limit) # number of rows to return; 0 - all
878-
buf.end_message()
885+
buf = self._build_execute_message(portal_name, limit)
879886
packet.write_buffer(buf)
880887

881888
packet.write_bytes(SYNC_MESSAGE)
@@ -894,30 +901,75 @@ cdef class CoreProtocol:
894901

895902
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
896903

897-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
898-
object bind_data):
899-
900-
cdef WriteBuffer buf
901-
904+
cdef _execute_many_init(self):
902905
self._ensure_connected()
903906
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
904907

905908
self.result = None
906909
self._discard_data = True
907-
self._execute_iter = bind_data
908-
self._execute_portal_name = portal_name
909-
self._execute_stmt_name = stmt_name
910910

911-
try:
912-
buf = <WriteBuffer>next(bind_data)
913-
except StopIteration:
914-
self._push_result()
915-
except Exception as e:
916-
self.result_type = RESULT_FAILED
917-
self.result = e
911+
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
912+
object bind_data):
913+
cdef:
914+
WriteBuffer packet
915+
WriteBuffer buf
916+
list buffers = []
917+
918+
if self.result_type == RESULT_FAILED:
919+
raise StopIteration(True)
920+
921+
while len(buffers) < _EXECUTE_MANY_BUF_NUM:
922+
packet = WriteBuffer.new()
923+
924+
while packet.len() < _EXECUTE_MANY_BUF_SIZE:
925+
try:
926+
buf = <WriteBuffer>next(bind_data)
927+
except StopIteration:
928+
if packet.len() > 0:
929+
buffers.append(packet)
930+
if len(buffers) > 0:
931+
self._writelines(buffers)
932+
raise StopIteration(True)
933+
else:
934+
raise StopIteration(False)
935+
except Exception as ex:
936+
raise StopIteration(ex)
937+
packet.write_buffer(
938+
self._build_bind_message(portal_name, stmt_name, buf))
939+
packet.write_buffer(
940+
self._build_execute_message(portal_name, 0))
941+
buffers.append(packet)
942+
self._writelines(buffers)
943+
944+
cdef _execute_many_done(self, bint data_sent):
945+
if data_sent:
946+
self._write(SYNC_MESSAGE)
947+
else:
918948
self._push_result()
949+
950+
cdef _execute_many_fail(self, object error):
951+
cdef WriteBuffer buf
952+
953+
self.result_type = RESULT_FAILED
954+
self.result = error
955+
956+
# We shall rollback in an implicit transaction to prevent partial
957+
# commit, while do nothing in an explicit transaction and leaving the
958+
# error to the user
959+
if self.is_in_transaction():
960+
self._execute_many_done(True)
919961
else:
920-
self._send_bind_message(portal_name, stmt_name, buf, 0)
962+
# Here if the implicit transaction is in `ignore_till_sync` mode,
963+
# the `ROLLBACK` will be ignored and `Sync` will restore the state;
964+
# or else the implicit transaction will be rolled back with a
965+
# warning saying that there was no transaction, but rollback is
966+
# done anyway, so we could ignore this warning.
967+
buf = self._build_parse_message('', 'ROLLBACK')
968+
buf.write_buffer(self._build_bind_message(
969+
'', '', self._build_empty_bind_data()))
970+
buf.write_buffer(self._build_execute_message('', 0))
971+
buf.write_bytes(SYNC_MESSAGE)
972+
self._write(buf)
921973

922974
cdef _execute(self, str portal_name, int32_t limit):
923975
cdef WriteBuffer buf
@@ -927,10 +979,7 @@ cdef class CoreProtocol:
927979

928980
self.result = []
929981

930-
buf = WriteBuffer.new_message(b'E')
931-
buf.write_str(portal_name, self.encoding) # name of the portal
932-
buf.write_int32(limit) # number of rows to return; 0 - all
933-
buf.end_message()
982+
buf = self._build_execute_message(portal_name, limit)
934983

935984
buf.write_bytes(SYNC_MESSAGE)
936985

@@ -1013,6 +1062,9 @@ cdef class CoreProtocol:
10131062
cdef _write(self, buf):
10141063
raise NotImplementedError
10151064

1065+
cdef _writelines(self, list buffers):
1066+
raise NotImplementedError
1067+
10161068
cdef _decode_row(self, const char* buf, ssize_t buf_len):
10171069
pass
10181070

0 commit comments

Comments
 (0)