Skip to content

Commit 46ecc07

Browse files
committed
Rewrite (again) executemany() to batch args for performance.
`Bind` and `Execute` message pairs are now batched into 4 x 32KB buffers to take advantage of `writelines()`. A single `Sync` is sent at the end, so that all argument sets are executed in the same transaction. Closes: #289
1 parent 43a7b21 commit 46ecc07

File tree

7 files changed

+321
-112
lines changed

7 files changed

+321
-112
lines changed

asyncpg/connection.py

+10
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
301301
302302
.. versionchanged:: 0.11.0
303303
`timeout` became a keyword-only parameter.
304+
305+
.. versionchanged:: 0.19.0
306+
The execution was changed to be in an implicit transaction if there
307+
was no explicit transaction, so that it will no longer end up with
308+
partial success. If you still need the previous behavior to
309+
progressively execute many args, please use a loop with a prepared
310+
statement instead.
304311
"""
305312
self._check_open()
306313
return await self._executemany(command, args, timeout)
@@ -821,6 +828,9 @@ async def _copy_in(self, copy_stmt, source, timeout):
821828
f = source
822829
elif isinstance(source, collections.abc.AsyncIterable):
823830
# assuming calling output returns an awaitable.
831+
# copy_in() is designed to handle very large amounts of data, and
832+
# the source async iterable is allowed to return an arbitrary
833+
# amount of data on every iteration.
824834
reader = source
825835
else:
826836
# assuming source is an instance supporting the buffer protocol.

asyncpg/prepared_stmt.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,24 @@ async def fetchrow(self, *args, timeout=None):
196196
return None
197197
return data[0]
198198

199-
async def __bind_execute(self, args, limit, timeout):
199+
@connresource.guarded
200+
async def executemany(self, args, *, timeout: float=None):
201+
"""Execute the statement for each sequence of arguments in *args*.
202+
203+
:param args: An iterable containing sequences of arguments.
204+
:param float timeout: Optional timeout value in seconds.
205+
:return None: This method discards the results of the operations.
206+
207+
.. versionadded:: 0.19.0
208+
"""
209+
return await self.__do_execute(
210+
lambda protocol: protocol.bind_execute_many(
211+
self._state, args, '', timeout))
212+
213+
async def __do_execute(self, executor):
200214
protocol = self._connection._protocol
201215
try:
202-
data, status, _ = await protocol.bind_execute(
203-
self._state, args, '', limit, True, timeout)
216+
return await executor(protocol)
204217
except exceptions.OutdatedSchemaCacheError:
205218
await self._connection.reload_schema_state()
206219
# We can not find all manually created prepared statements, so just
@@ -209,6 +222,11 @@ async def __bind_execute(self, args, limit, timeout):
209222
# invalidate themselves (unfortunately, clearing caches again).
210223
self._state.mark_closed()
211224
raise
225+
226+
async def __bind_execute(self, args, limit, timeout):
227+
data, status, _ = await self.__do_execute(
228+
lambda protocol: protocol.bind_execute(
229+
self._state, args, '', limit, True, timeout))
212230
self._last_status = status
213231
return data
214232

asyncpg/protocol/consts.pxi

+2
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
DEF _MAXINT32 = 2**31 - 1
99
DEF _COPY_BUFFER_SIZE = 524288
1010
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
11+
DEF _EXECUTE_MANY_BUF_NUM = 4
12+
DEF _EXECUTE_MANY_BUF_SIZE = 32768

asyncpg/protocol/coreproto.pxd

+11-7
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,6 @@ cdef class CoreProtocol:
7676
bint _skip_discard
7777
bint _discard_data
7878

79-
# executemany support data
80-
object _execute_iter
81-
str _execute_portal_name
82-
str _execute_stmt_name
83-
8479
ConnectionStatus con_status
8580
ProtocolState state
8681
TransactionStatus xact_status
@@ -105,6 +100,7 @@ cdef class CoreProtocol:
105100
# True - completed, False - suspended
106101
bint result_execute_completed
107102

103+
cpdef is_in_transaction(self)
108104
cdef _process__auth(self, char mtype)
109105
cdef _process__prepare(self, char mtype)
110106
cdef _process__bind_execute(self, char mtype)
@@ -135,6 +131,7 @@ cdef class CoreProtocol:
135131
cdef _auth_password_message_md5(self, bytes salt)
136132

137133
cdef _write(self, buf)
134+
cdef _writelines(self, list buffers)
138135

139136
cdef _read_server_messages(self)
140137

@@ -144,9 +141,13 @@ cdef class CoreProtocol:
144141

145142
cdef _ensure_connected(self)
146143

144+
cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
147145
cdef WriteBuffer _build_bind_message(self, str portal_name,
148146
str stmt_name,
149147
WriteBuffer bind_data)
148+
cdef WriteBuffer _build_empty_bind_data(self)
149+
cdef WriteBuffer _build_execute_message(self, str portal_name,
150+
int32_t limit)
150151

151152

152153
cdef _connect(self)
@@ -155,8 +156,11 @@ cdef class CoreProtocol:
155156
WriteBuffer bind_data, int32_t limit)
156157
cdef _bind_execute(self, str portal_name, str stmt_name,
157158
WriteBuffer bind_data, int32_t limit)
158-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
159-
object bind_data)
159+
cdef _execute_many_init(self)
160+
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
161+
object bind_data)
162+
cdef _execute_many_done(self, bint data_sent)
163+
cdef _execute_many_fail(self, object error)
160164
cdef _bind(self, str portal_name, str stmt_name,
161165
WriteBuffer bind_data)
162166
cdef _execute(self, str portal_name, int32_t limit)

asyncpg/protocol/coreproto.pyx

+103-51
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ cdef class CoreProtocol:
2222
self.xact_status = PQTRANS_IDLE
2323
self.encoding = 'utf-8'
2424

25-
# executemany support data
26-
self._execute_iter = None
27-
self._execute_portal_name = None
28-
self._execute_stmt_name = None
29-
3025
self._reset_result()
3126

27+
cpdef is_in_transaction(self):
28+
# PQTRANS_INTRANS = idle, within transaction block
29+
# PQTRANS_INERROR = idle, within failed transaction
30+
return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
31+
3232
cdef _read_server_messages(self):
3333
cdef:
3434
char mtype
@@ -253,22 +253,7 @@ cdef class CoreProtocol:
253253
elif mtype == b'Z':
254254
# ReadyForQuery
255255
self._parse_msg_ready_for_query()
256-
if self.result_type == RESULT_FAILED:
257-
self._push_result()
258-
else:
259-
try:
260-
buf = <WriteBuffer>next(self._execute_iter)
261-
except StopIteration:
262-
self._push_result()
263-
except Exception as e:
264-
self.result_type = RESULT_FAILED
265-
self.result = e
266-
self._push_result()
267-
else:
268-
# Next iteration over the executemany() arg sequence
269-
self._send_bind_message(
270-
self._execute_portal_name, self._execute_stmt_name,
271-
buf, 0)
256+
self._push_result()
272257

273258
elif mtype == b'I':
274259
# EmptyQueryResponse
@@ -687,6 +672,17 @@ cdef class CoreProtocol:
687672
if self.con_status != CONNECTION_OK:
688673
raise apg_exc.InternalClientError('not connected')
689674

675+
cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
676+
cdef WriteBuffer buf
677+
678+
buf = WriteBuffer.new_message(b'P')
679+
buf.write_str(stmt_name, self.encoding)
680+
buf.write_str(query, self.encoding)
681+
buf.write_int16(0)
682+
683+
buf.end_message()
684+
return buf
685+
690686
cdef WriteBuffer _build_bind_message(self, str portal_name,
691687
str stmt_name,
692688
WriteBuffer bind_data):
@@ -702,6 +698,25 @@ cdef class CoreProtocol:
702698
buf.end_message()
703699
return buf
704700

701+
cdef WriteBuffer _build_empty_bind_data(self):
702+
cdef WriteBuffer buf
703+
buf = WriteBuffer.new()
704+
buf.write_int16(0) # The number of parameter format codes
705+
buf.write_int16(0) # The number of parameter values
706+
buf.write_int16(0) # The number of result-column format codes
707+
return buf
708+
709+
cdef WriteBuffer _build_execute_message(self, str portal_name,
710+
int32_t limit):
711+
cdef WriteBuffer buf
712+
713+
buf = WriteBuffer.new_message(b'E')
714+
buf.write_str(portal_name, self.encoding) # name of the portal
715+
buf.write_int32(limit) # number of rows to return; 0 - all
716+
717+
buf.end_message()
718+
return buf
719+
705720
# API for subclasses
706721

707722
cdef _connect(self):
@@ -752,12 +767,7 @@ cdef class CoreProtocol:
752767
self._ensure_connected()
753768
self._set_state(PROTOCOL_PREPARE)
754769

755-
buf = WriteBuffer.new_message(b'P')
756-
buf.write_str(stmt_name, self.encoding)
757-
buf.write_str(query, self.encoding)
758-
buf.write_int16(0)
759-
buf.end_message()
760-
packet = buf
770+
packet = self._build_parse_message(stmt_name, query)
761771

762772
buf = WriteBuffer.new_message(b'D')
763773
buf.write_byte(b'S')
@@ -779,10 +789,7 @@ cdef class CoreProtocol:
779789
buf = self._build_bind_message(portal_name, stmt_name, bind_data)
780790
packet = buf
781791

782-
buf = WriteBuffer.new_message(b'E')
783-
buf.write_str(portal_name, self.encoding) # name of the portal
784-
buf.write_int32(limit) # number of rows to return; 0 - all
785-
buf.end_message()
792+
buf = self._build_execute_message(portal_name, limit)
786793
packet.write_buffer(buf)
787794

788795
packet.write_bytes(SYNC_MESSAGE)
@@ -801,30 +808,75 @@ cdef class CoreProtocol:
801808

802809
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
803810

804-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
805-
object bind_data):
806-
807-
cdef WriteBuffer buf
808-
811+
cdef _execute_many_init(self):
809812
self._ensure_connected()
810813
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
811814

812815
self.result = None
813816
self._discard_data = True
814-
self._execute_iter = bind_data
815-
self._execute_portal_name = portal_name
816-
self._execute_stmt_name = stmt_name
817817

818-
try:
819-
buf = <WriteBuffer>next(bind_data)
820-
except StopIteration:
821-
self._push_result()
822-
except Exception as e:
823-
self.result_type = RESULT_FAILED
824-
self.result = e
818+
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
819+
object bind_data):
820+
cdef:
821+
WriteBuffer packet
822+
WriteBuffer buf
823+
list buffers = []
824+
825+
if self.result_type == RESULT_FAILED:
826+
raise StopIteration(True)
827+
828+
while len(buffers) < _EXECUTE_MANY_BUF_NUM:
829+
packet = WriteBuffer.new()
830+
831+
while packet.len() < _EXECUTE_MANY_BUF_SIZE:
832+
try:
833+
buf = <WriteBuffer>next(bind_data)
834+
except StopIteration:
835+
if packet.len() > 0:
836+
buffers.append(packet)
837+
if len(buffers) > 0:
838+
self._writelines(buffers)
839+
raise StopIteration(True)
840+
else:
841+
raise StopIteration(False)
842+
except Exception as ex:
843+
raise StopIteration(ex)
844+
packet.write_buffer(
845+
self._build_bind_message(portal_name, stmt_name, buf))
846+
packet.write_buffer(
847+
self._build_execute_message(portal_name, 0))
848+
buffers.append(packet)
849+
self._writelines(buffers)
850+
851+
cdef _execute_many_done(self, bint data_sent):
852+
if data_sent:
853+
self._write(SYNC_MESSAGE)
854+
else:
825855
self._push_result()
856+
857+
cdef _execute_many_fail(self, object error):
858+
cdef WriteBuffer buf
859+
860+
self.result_type = RESULT_FAILED
861+
self.result = error
862+
863+
# We shall rollback in an implicit transaction to prevent partial
864+
# commit, while do nothing in an explicit transaction and leaving the
865+
# error to the user
866+
if self.is_in_transaction():
867+
self._execute_many_done(True)
826868
else:
827-
self._send_bind_message(portal_name, stmt_name, buf, 0)
869+
# Here if the implicit transaction is in `ignore_till_sync` mode,
870+
# the `ROLLBACK` will be ignored and `Sync` will restore the state;
871+
# or else the implicit transaction will be rolled back with a
872+
# warning saying that there was no transaction, but rollback is
873+
# done anyway, so we could ignore this warning.
874+
buf = self._build_parse_message('', 'ROLLBACK')
875+
buf.write_buffer(self._build_bind_message(
876+
'', '', self._build_empty_bind_data()))
877+
buf.write_buffer(self._build_execute_message('', 0))
878+
buf.write_bytes(SYNC_MESSAGE)
879+
self._write(buf)
828880

829881
cdef _execute(self, str portal_name, int32_t limit):
830882
cdef WriteBuffer buf
@@ -834,10 +886,7 @@ cdef class CoreProtocol:
834886

835887
self.result = []
836888

837-
buf = WriteBuffer.new_message(b'E')
838-
buf.write_str(portal_name, self.encoding) # name of the portal
839-
buf.write_int32(limit) # number of rows to return; 0 - all
840-
buf.end_message()
889+
buf = self._build_execute_message(portal_name, limit)
841890

842891
buf.write_bytes(SYNC_MESSAGE)
843892

@@ -920,6 +969,9 @@ cdef class CoreProtocol:
920969
cdef _write(self, buf):
921970
raise NotImplementedError
922971

972+
cdef _writelines(self, list buffers):
973+
raise NotImplementedError
974+
923975
cdef _decode_row(self, const char* buf, ssize_t buf_len):
924976
pass
925977

0 commit comments

Comments
 (0)