Skip to content

Commit 19a2798

Browse files
committed
CRF MagicStack#295, use simple query ROLLBACK in implicit transaction only
1 parent 5c4f50a commit 19a2798

File tree

2 files changed

+35
-20
lines changed

2 files changed

+35
-20
lines changed

asyncpg/protocol/coreproto.pyx

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -810,34 +810,29 @@ cdef class CoreProtocol:
810810
self.transport.writelines(buffers)
811811

812812
cdef _execute_many_rollback(self, object error):
813-
cdef:
814-
WriteBuffer packet
815-
WriteBuffer buf
813+
cdef WriteBuffer buf
816814

817815
self._ensure_connected()
818816
if self.state != PROTOCOL_CANCELLED:
819817
self._set_state(PROTOCOL_BIND_EXECUTE_MANY_ROLLBACK)
820-
packet = WriteBuffer.new()
821-
822-
# We have no idea if we are in an explicit transaction or not,
823-
# therefore we raise an exception in the database to mark current
824-
# transaction as failed anyway to prevent partial commit
825-
buf = self._build_parse_message('',
826-
'DO language plpgsql $$ BEGIN '
827-
'RAISE transaction_rollback; '
828-
'END$$;')
829-
packet.write_buffer(buf)
830-
buf = self._build_bind_message('', '', WriteBuffer.new())
831-
packet.write_buffer(buf)
832-
buf = self._build_execute_message('', 0)
833-
packet.write_buffer(buf)
834-
packet.write_bytes(SYNC_MESSAGE)
835-
836-
self.transport.write(memoryview(packet))
837818

838819
self.result_type = RESULT_FAILED
839820
self.result = error
840821

822+
# We shall rollback in an implicit transaction to prevent partial
823+
# commit, while do nothing in an explicit transaction and leaving the
824+
# error to the user
825+
if self.xact_status == PQTRANS_IDLE:
826+
buf = WriteBuffer.new_message(b'Q')
827+
# ROLLBACK here won't cause server to send RowDescription,
828+
# CopyInResponse or CopyOutResponse which we are not expecting, but
829+
# the server will send ReadyForQuery which finishes executemany()
830+
buf.write_str('ROLLBACK;', self.encoding)
831+
buf.end_message()
832+
self._write(buf)
833+
else:
834+
self._write_sync_message()
835+
841836
cdef _execute_many_end(self, object sync):
842837
if sync is True:
843838
self._write_sync_message()

tests/test_execute.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,23 @@ async def test_execute_many_rollback(self):
231231
''', AsyncIterable(sleep=0.3, sleep_start=32), timeout=0.5)
232232
result = await self.con.fetch('SELECT * FROM exmany')
233233
self.assertEqual(result, [])
234+
235+
async def test_execute_many_rollback_in_transaction(self):
236+
# User error raised inside the transaction block
237+
with self.assertRaises(ZeroDivisionError):
238+
async with self.con.transaction():
239+
await self.con.executemany('''
240+
INSERT INTO exmany VALUES($1, $2)
241+
''', (('a' * 32768, y + y / y) for y in range(10, -1, -1)))
242+
result = await self.con.fetch('SELECT * FROM exmany')
243+
self.assertEqual(result, [])
244+
245+
# User error captured inside the transaction block, executemany() is
246+
# partially committed (4 * 2 buffers sent, [3, 2] not sent)
247+
async with self.con.transaction():
248+
with self.assertRaises(ZeroDivisionError):
249+
await self.con.executemany('''
250+
INSERT INTO exmany VALUES($1, $2)
251+
''', (('a' * 32768, y + y / y) for y in range(10, -1, -1)))
252+
result = await self.con.fetch('SELECT * FROM exmany')
253+
self.assertEqual([x[1] for x in result], [11, 10, 9, 8, 7, 6, 5, 4])

0 commit comments

Comments
 (0)