Skip to content

Commit b137184

Browse files
elpransfantix
authored andcommitted
Keep track of timeout in executemany properly
1 parent c4c5b5e commit b137184

File tree

2 files changed

+64
-17
lines changed

2 files changed

+64
-17
lines changed

asyncpg/protocol/protocol.pyx

+20-4
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ cdef class BaseProtocol(CoreProtocol):
210210

211211
self._check_state()
212212
timeout = self._get_timeout_impl(timeout)
213+
timer = Timer(timeout)
213214

214215
# Make sure the argument sequence is encoded lazily with
215216
# this generator expression to keep the memory pressure under
@@ -230,12 +231,20 @@ cdef class BaseProtocol(CoreProtocol):
230231
self.queries_count += 1
231232

232233
while more:
233-
await self.writing_allowed.wait()
234-
# On Windows the above event somehow won't allow context
235-
# switch, so forcing one with sleep(0) here
236-
await asyncio.sleep(0)
234+
with timer:
235+
await asyncio.wait_for(
236+
self.writing_allowed.wait(),
237+
timeout=timer.get_remaining_budget())
238+
# On Windows the above event somehow won't allow context
239+
# switch, so forcing one with sleep(0) here
240+
await asyncio.sleep(0)
241+
if not timer.has_budget_greater_than(0):
242+
raise asyncio.TimeoutError
237243
more = self._bind_execute_many_more() # network op
238244

245+
except asyncio.TimeoutError as e:
246+
self._bind_execute_many_fail(e) # network op
247+
239248
except Exception as ex:
240249
waiter.set_exception(ex)
241250
self._coreproto_error()
@@ -951,6 +960,13 @@ class Timer:
951960
def get_remaining_budget(self):
952961
return self._budget
953962

963+
def has_budget_greater_than(self, amount):
964+
if self._budget is None:
965+
# Unlimited budget.
966+
return True
967+
else:
968+
return self._budget > amount
969+
954970

955971
class Protocol(BaseProtocol, asyncio.Protocol):
956972
pass

tests/test_execute.py

+44-13
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,13 @@ class TestExecuteMany(tb.ConnectedTestCase):
103103
def setUp(self):
104104
super().setUp()
105105
self.loop.run_until_complete(self.con.execute(
106-
'CREATE TEMP TABLE exmany (a text, b int PRIMARY KEY)'))
106+
'CREATE TABLE exmany (a text, b int PRIMARY KEY)'))
107107

108108
def tearDown(self):
109109
self.loop.run_until_complete(self.con.execute('DROP TABLE exmany'))
110110
super().tearDown()
111111

112-
async def test_basic(self):
112+
async def test_executemany_basic(self):
113113
result = await self.con.executemany('''
114114
INSERT INTO exmany VALUES($1, $2)
115115
''', [
@@ -139,7 +139,7 @@ async def test_basic(self):
139139
('a', 1), ('b', 2), ('c', 3), ('d', 4)
140140
])
141141

142-
async def test_bad_input(self):
142+
async def test_executemany_bad_input(self):
143143
bad_data = ([1 / 0] for v in range(10))
144144

145145
with self.assertRaises(ZeroDivisionError):
@@ -154,7 +154,7 @@ async def test_bad_input(self):
154154
INSERT INTO exmany (b)VALUES($1)
155155
''', good_data)
156156

157-
async def test_server_failure(self):
157+
async def test_executemany_server_failure(self):
158158
with self.assertRaises(UniqueViolationError):
159159
await self.con.executemany('''
160160
INSERT INTO exmany VALUES($1, $2)
@@ -164,7 +164,7 @@ async def test_server_failure(self):
164164
result = await self.con.fetch('SELECT * FROM exmany')
165165
self.assertEqual(result, [])
166166

167-
async def test_server_failure_after_writes(self):
167+
async def test_executemany_server_failure_after_writes(self):
168168
with self.assertRaises(UniqueViolationError):
169169
await self.con.executemany('''
170170
INSERT INTO exmany VALUES($1, $2)
@@ -174,7 +174,7 @@ async def test_server_failure_after_writes(self):
174174
result = await self.con.fetch('SELECT b FROM exmany')
175175
self.assertEqual(result, [])
176176

177-
async def test_server_failure_during_writes(self):
177+
async def test_executemany_server_failure_during_writes(self):
178178
# failure at the beginning, server error detected in the middle
179179
pos = 0
180180

@@ -195,23 +195,54 @@ def gen():
195195
self.assertEqual(result, [])
196196
self.assertLess(pos, 128, 'should stop early')
197197

198-
async def test_client_failure_after_writes(self):
198+
async def test_executemany_client_failure_after_writes(self):
199199
with self.assertRaises(ZeroDivisionError):
200200
await self.con.executemany('''
201201
INSERT INTO exmany VALUES($1, $2)
202202
''', (('a' * 32768, y + y / y) for y in range(10, -1, -1)))
203203
result = await self.con.fetch('SELECT b FROM exmany')
204204
self.assertEqual(result, [])
205205

206-
async def test_timeout(self):
206+
async def test_executemany_timeout(self):
207207
with self.assertRaises(asyncio.TimeoutError):
208208
await self.con.executemany('''
209-
INSERT INTO exmany VALUES(pg_sleep(0.1), $1)
210-
''', [[x] for x in range(128)], timeout=0.5)
209+
INSERT INTO exmany VALUES(pg_sleep(0.1) || $1, $2)
210+
''', [('a' * 32768, x) for x in range(128)], timeout=0.5)
211211
result = await self.con.fetch('SELECT * FROM exmany')
212212
self.assertEqual(result, [])
213213

214-
async def test_client_failure_in_transaction(self):
214+
async def test_executemany_timeout_flow_control(self):
215+
event = asyncio.Event()
216+
217+
async def locker():
218+
test_func = getattr(self, self._testMethodName).__func__
219+
opts = getattr(test_func, '__connect_options__', {})
220+
conn = await self.connect(**opts)
221+
try:
222+
tx = conn.transaction()
223+
await tx.start()
224+
await conn.execute("UPDATE exmany SET a = '1' WHERE b = 10")
225+
event.set()
226+
await asyncio.sleep(1)
227+
await tx.rollback()
228+
finally:
229+
event.set()
230+
await conn.close()
231+
232+
await self.con.executemany('''
233+
INSERT INTO exmany VALUES(NULL, $1)
234+
''', [(x,) for x in range(128)])
235+
fut = asyncio.ensure_future(locker())
236+
await event.wait()
237+
with self.assertRaises(asyncio.TimeoutError):
238+
await self.con.executemany('''
239+
UPDATE exmany SET a = $1 WHERE b = $2
240+
''', [('a' * 32768, x) for x in range(128)], timeout=0.5)
241+
await fut
242+
result = await self.con.fetch('SELECT * FROM exmany WHERE a IS NOT NULL')
243+
self.assertEqual(result, [])
244+
245+
async def test_executemany_client_failure_in_transaction(self):
215246
tx = self.con.transaction()
216247
await tx.start()
217248
with self.assertRaises(ZeroDivisionError):
@@ -226,7 +257,7 @@ async def test_client_failure_in_transaction(self):
226257
result = await self.con.fetch('SELECT b FROM exmany')
227258
self.assertEqual(result, [])
228259

229-
async def test_client_server_failure_conflict(self):
260+
async def test_executemany_client_server_failure_conflict(self):
230261
self.con._transport.set_write_buffer_limits(65536 * 64, 16384 * 64)
231262
with self.assertRaises(UniqueViolationError):
232263
await self.con.executemany('''
@@ -235,7 +266,7 @@ async def test_client_server_failure_conflict(self):
235266
result = await self.con.fetch('SELECT b FROM exmany')
236267
self.assertEqual(result, [])
237268

238-
async def test_prepare(self):
269+
async def test_executemany_prepare(self):
239270
stmt = await self.con.prepare('''
240271
INSERT INTO exmany VALUES($1, $2)
241272
''')

0 commit comments

Comments
 (0)