Skip to content

Commit 557f515

Browse files
committed
Rewrite executemany() to batch args for performance.
Now `Bind` and `Execute` pairs are batched into 4 x 32KB buffers to take advantage of `writelines()`. A single `Sync` is sent at last, so that all args live in the same transaction. Closes: #289
1 parent 92aa806 commit 557f515

File tree

7 files changed

+330
-97
lines changed

7 files changed

+330
-97
lines changed

asyncpg/connection.py

+10
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
331331
332332
.. versionchanged:: 0.11.0
333333
`timeout` became a keyword-only parameter.
334+
335+
.. versionchanged:: 0.22.0
336+
The execution was changed to be in an implicit transaction if there
337+
was no explicit transaction, so that it will no longer end up with
338+
partial success. If you still need the previous behavior to
339+
progressively execute many args, please use a loop with prepared
340+
statement instead.
334341
"""
335342
self._check_open()
336343
return await self._executemany(command, args, timeout)
@@ -1010,6 +1017,9 @@ async def _copy_in(self, copy_stmt, source, timeout):
10101017
f = source
10111018
elif isinstance(source, collections.abc.AsyncIterable):
10121019
# assuming calling output returns an awaitable.
1020+
# copy_in() is designed to handle very large amounts of data, and
1021+
# the source async iterable is allowed to return an arbitrary
1022+
# amount of data on every iteration.
10131023
reader = source
10141024
else:
10151025
# assuming source is an instance supporting the buffer protocol.

asyncpg/prepared_stmt.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,24 @@ async def fetchrow(self, *args, timeout=None):
202202
return None
203203
return data[0]
204204

205-
async def __bind_execute(self, args, limit, timeout):
205+
@connresource.guarded
206+
async def executemany(self, args, *, timeout: float=None):
207+
"""Execute the statement for each sequence of arguments in *args*.
208+
209+
:param args: An iterable containing sequences of arguments.
210+
:param float timeout: Optional timeout value in seconds.
211+
:return None: This method discards the results of the operations.
212+
213+
.. versionadded:: 0.22.0
214+
"""
215+
return await self.__do_execute(
216+
lambda protocol: protocol.bind_execute_many(
217+
self._state, args, '', timeout))
218+
219+
async def __do_execute(self, executor):
206220
protocol = self._connection._protocol
207221
try:
208-
data, status, _ = await protocol.bind_execute(
209-
self._state, args, '', limit, True, timeout)
222+
return await executor(protocol)
210223
except exceptions.OutdatedSchemaCacheError:
211224
await self._connection.reload_schema_state()
212225
# We can not find all manually created prepared statements, so just
@@ -215,6 +228,11 @@ async def __bind_execute(self, args, limit, timeout):
215228
# invalidate themselves (unfortunately, clearing caches again).
216229
self._state.mark_closed()
217230
raise
231+
232+
async def __bind_execute(self, args, limit, timeout):
233+
data, status, _ = await self.__do_execute(
234+
lambda protocol: protocol.bind_execute(
235+
self._state, args, '', limit, True, timeout))
218236
self._last_status = status
219237
return data
220238

asyncpg/protocol/consts.pxi

+2
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
DEF _MAXINT32 = 2**31 - 1
99
DEF _COPY_BUFFER_SIZE = 524288
1010
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
11+
DEF _EXECUTE_MANY_BUF_NUM = 4
12+
DEF _EXECUTE_MANY_BUF_SIZE = 32768

asyncpg/protocol/coreproto.pxd

+9-2
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ cdef class CoreProtocol:
114114
# True - completed, False - suspended
115115
bint result_execute_completed
116116

117+
cpdef is_in_transaction(self)
117118
cdef _process__auth(self, char mtype)
118119
cdef _process__prepare(self, char mtype)
119120
cdef _process__bind_execute(self, char mtype)
@@ -146,6 +147,7 @@ cdef class CoreProtocol:
146147
cdef _auth_password_message_sasl_continue(self, bytes server_response)
147148

148149
cdef _write(self, buf)
150+
cdef _writelines(self, list buffers)
149151

150152
cdef _read_server_messages(self)
151153

@@ -155,9 +157,13 @@ cdef class CoreProtocol:
155157

156158
cdef _ensure_connected(self)
157159

160+
cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
158161
cdef WriteBuffer _build_bind_message(self, str portal_name,
159162
str stmt_name,
160163
WriteBuffer bind_data)
164+
cdef WriteBuffer _build_empty_bind_data(self)
165+
cdef WriteBuffer _build_execute_message(self, str portal_name,
166+
int32_t limit)
161167

162168

163169
cdef _connect(self)
@@ -166,8 +172,9 @@ cdef class CoreProtocol:
166172
WriteBuffer bind_data, int32_t limit)
167173
cdef _bind_execute(self, str portal_name, str stmt_name,
168174
WriteBuffer bind_data, int32_t limit)
169-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
170-
object bind_data)
175+
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
176+
object bind_data)
177+
cdef bint _bind_execute_many_more(self, bint first=*)
171178
cdef _bind(self, str portal_name, str stmt_name,
172179
WriteBuffer bind_data)
173180
cdef _execute(self, str portal_name, int32_t limit)

asyncpg/protocol/coreproto.pyx

+129-45
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ cdef class CoreProtocol:
3434

3535
self._reset_result()
3636

37+
cpdef is_in_transaction(self):
38+
# PQTRANS_INTRANS = idle, within transaction block
39+
# PQTRANS_INERROR = idle, within failed transaction
40+
return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
41+
3742
cdef _read_server_messages(self):
3843
cdef:
3944
char mtype
@@ -263,27 +268,16 @@ cdef class CoreProtocol:
263268
elif mtype == b'Z':
264269
# ReadyForQuery
265270
self._parse_msg_ready_for_query()
266-
if self.result_type == RESULT_FAILED:
267-
self._push_result()
268-
else:
269-
try:
270-
buf = <WriteBuffer>next(self._execute_iter)
271-
except StopIteration:
272-
self._push_result()
273-
except Exception as e:
274-
self.result_type = RESULT_FAILED
275-
self.result = e
276-
self._push_result()
277-
else:
278-
# Next iteration over the executemany() arg sequence
279-
self._send_bind_message(
280-
self._execute_portal_name, self._execute_stmt_name,
281-
buf, 0)
271+
self._push_result()
282272

283273
elif mtype == b'I':
284274
# EmptyQueryResponse
285275
self.buffer.discard_message()
286276

277+
elif mtype == b'1':
278+
# ParseComplete
279+
self.buffer.discard_message()
280+
287281
cdef _process__bind(self, char mtype):
288282
if mtype == b'E':
289283
# ErrorResponse
@@ -780,6 +774,17 @@ cdef class CoreProtocol:
780774
if self.con_status != CONNECTION_OK:
781775
raise apg_exc.InternalClientError('not connected')
782776

777+
cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
778+
cdef WriteBuffer buf
779+
780+
buf = WriteBuffer.new_message(b'P')
781+
buf.write_str(stmt_name, self.encoding)
782+
buf.write_str(query, self.encoding)
783+
buf.write_int16(0)
784+
785+
buf.end_message()
786+
return buf
787+
783788
cdef WriteBuffer _build_bind_message(self, str portal_name,
784789
str stmt_name,
785790
WriteBuffer bind_data):
@@ -795,6 +800,25 @@ cdef class CoreProtocol:
795800
buf.end_message()
796801
return buf
797802

803+
cdef WriteBuffer _build_empty_bind_data(self):
804+
cdef WriteBuffer buf
805+
buf = WriteBuffer.new()
806+
buf.write_int16(0) # The number of parameter format codes
807+
buf.write_int16(0) # The number of parameter values
808+
buf.write_int16(0) # The number of result-column format codes
809+
return buf
810+
811+
cdef WriteBuffer _build_execute_message(self, str portal_name,
812+
int32_t limit):
813+
cdef WriteBuffer buf
814+
815+
buf = WriteBuffer.new_message(b'E')
816+
buf.write_str(portal_name, self.encoding) # name of the portal
817+
buf.write_int32(limit) # number of rows to return; 0 - all
818+
819+
buf.end_message()
820+
return buf
821+
798822
# API for subclasses
799823

800824
cdef _connect(self):
@@ -845,12 +869,7 @@ cdef class CoreProtocol:
845869
self._ensure_connected()
846870
self._set_state(PROTOCOL_PREPARE)
847871

848-
buf = WriteBuffer.new_message(b'P')
849-
buf.write_str(stmt_name, self.encoding)
850-
buf.write_str(query, self.encoding)
851-
buf.write_int16(0)
852-
buf.end_message()
853-
packet = buf
872+
packet = self._build_parse_message(stmt_name, query)
854873

855874
buf = WriteBuffer.new_message(b'D')
856875
buf.write_byte(b'S')
@@ -872,10 +891,7 @@ cdef class CoreProtocol:
872891
buf = self._build_bind_message(portal_name, stmt_name, bind_data)
873892
packet = buf
874893

875-
buf = WriteBuffer.new_message(b'E')
876-
buf.write_str(portal_name, self.encoding) # name of the portal
877-
buf.write_int32(limit) # number of rows to return; 0 - all
878-
buf.end_message()
894+
buf = self._build_execute_message(portal_name, limit)
879895
packet.write_buffer(buf)
880896

881897
packet.write_bytes(SYNC_MESSAGE)
@@ -894,11 +910,8 @@ cdef class CoreProtocol:
894910

895911
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
896912

897-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
898-
object bind_data):
899-
900-
cdef WriteBuffer buf
901-
913+
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
914+
object bind_data):
902915
self._ensure_connected()
903916
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
904917

@@ -907,17 +920,88 @@ cdef class CoreProtocol:
907920
self._execute_iter = bind_data
908921
self._execute_portal_name = portal_name
909922
self._execute_stmt_name = stmt_name
923+
return self._bind_execute_many_more(True)
910924

911-
try:
912-
buf = <WriteBuffer>next(bind_data)
913-
except StopIteration:
914-
self._push_result()
915-
except Exception as e:
916-
self.result_type = RESULT_FAILED
917-
self.result = e
918-
self._push_result()
919-
else:
920-
self._send_bind_message(portal_name, stmt_name, buf, 0)
925+
cdef bint _bind_execute_many_more(self, bint first=False):
926+
cdef:
927+
WriteBuffer packet
928+
WriteBuffer buf
929+
list buffers = []
930+
931+
# as we keep sending, the server may return an error early
932+
if self.result_type == RESULT_FAILED:
933+
self._write(SYNC_MESSAGE)
934+
return False
935+
936+
# collect up to four 32KB buffers to send
937+
# https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051
938+
while len(buffers) < _EXECUTE_MANY_BUF_NUM:
939+
packet = WriteBuffer.new()
940+
941+
# fill one 32KB buffer
942+
while packet.len() < _EXECUTE_MANY_BUF_SIZE:
943+
try:
944+
# grab one item from the input
945+
buf = <WriteBuffer>next(self._execute_iter)
946+
947+
# reached the end of the input
948+
except StopIteration:
949+
if first:
950+
# if we never send anything, simply set the result
951+
self._push_result()
952+
else:
953+
# otherwise, append SYNC and send the buffers
954+
packet.write_bytes(SYNC_MESSAGE)
955+
buffers.append(packet)
956+
self._writelines(buffers)
957+
return False
958+
959+
# error in input, give up the buffers and cleanup
960+
except Exception as ex:
961+
self.result_type = RESULT_FAILED
962+
self.result = ex
963+
if first:
964+
self._push_result()
965+
elif self.is_in_transaction():
966+
# we're in an explicit transaction, just SYNC
967+
self._write(SYNC_MESSAGE)
968+
else:
969+
# In an implicit transaction, if `ignore_till_sync`,
970+
# `ROLLBACK` will be ignored and `Sync` will restore
971+
# the state; or the transaction will be rolled back
972+
# with a warning saying that there was no transaction,
973+
# but rollback is done anyway, so we could safely
974+
# ignore this warning.
975+
# GOTCHA: simple query message will be ignored if
976+
# `ignore_till_sync` is set.
977+
buf = self._build_parse_message('', 'ROLLBACK')
978+
buf.write_buffer(self._build_bind_message(
979+
'', '', self._build_empty_bind_data()))
980+
buf.write_buffer(self._build_execute_message('', 0))
981+
buf.write_bytes(SYNC_MESSAGE)
982+
self._write(buf)
983+
return False
984+
985+
# all good, write to the buffer
986+
first = False
987+
packet.write_buffer(
988+
self._build_bind_message(
989+
self._execute_portal_name,
990+
self._execute_stmt_name,
991+
buf,
992+
)
993+
)
994+
packet.write_buffer(
995+
self._build_execute_message(self._execute_portal_name, 0,
996+
)
997+
)
998+
999+
# collected one buffer
1000+
buffers.append(packet)
1001+
1002+
# write to the wire, and signal the caller for more to send
1003+
self._writelines(buffers)
1004+
return True
9211005

9221006
cdef _execute(self, str portal_name, int32_t limit):
9231007
cdef WriteBuffer buf
@@ -927,10 +1011,7 @@ cdef class CoreProtocol:
9271011

9281012
self.result = []
9291013

930-
buf = WriteBuffer.new_message(b'E')
931-
buf.write_str(portal_name, self.encoding) # name of the portal
932-
buf.write_int32(limit) # number of rows to return; 0 - all
933-
buf.end_message()
1014+
buf = self._build_execute_message(portal_name, limit)
9341015

9351016
buf.write_bytes(SYNC_MESSAGE)
9361017

@@ -1013,6 +1094,9 @@ cdef class CoreProtocol:
10131094
cdef _write(self, buf):
10141095
raise NotImplementedError
10151096

1097+
cdef _writelines(self, list buffers):
1098+
raise NotImplementedError
1099+
10161100
cdef _decode_row(self, const char* buf, ssize_t buf_len):
10171101
pass
10181102

0 commit comments

Comments
 (0)