Skip to content

Commit 1674dec

Browse files
elprans1st1
authored andcommitted
Implement conistent handling of timeouts.
The timeout logic is currently a bit of a mess. This commit attempts to tidy things up. Most importantly, the timeout budget is now applied consistently to the _entire_ call, whereas previously multiple consecutive operations used the same timeout value, making it possible for the overall run time to exceed the timeout. Secondly, tighten the validation for timeouts: booleans are not accepted, and neither any value that cannot be converted to float.
1 parent 749d857 commit 1674dec

File tree

5 files changed

+150
-34
lines changed

5 files changed

+150
-34
lines changed

asyncpg/connection.py

+57-17
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import os
1212
import socket
1313
import struct
14+
import time
1415
import urllib.parse
1516

1617
from . import cursor
@@ -60,6 +61,22 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6061
self._stmt_cache = collections.OrderedDict()
6162
self._stmts_to_close = set()
6263

64+
if command_timeout is not None:
65+
try:
66+
if isinstance(command_timeout, bool):
67+
raise ValueError
68+
69+
command_timeout = float(command_timeout)
70+
71+
if command_timeout < 0:
72+
raise ValueError
73+
74+
except ValueError:
75+
raise ValueError(
76+
'invalid command_timeout value: '
77+
'expected non-negative float (got {!r})'.format(
78+
command_timeout)) from None
79+
6380
self._command_timeout = command_timeout
6481

6582
self._listeners = {}
@@ -187,7 +204,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
187204
if not args:
188205
return await self._protocol.query(query, timeout)
189206

190-
_, status, _ = await self._do_execute(query, args, 0, timeout, True)
207+
_, status, _ = await self._execute(query, args, 0, timeout, True)
191208
return status.decode()
192209

193210
async def executemany(self, command: str, args, timeout: float=None):
@@ -208,8 +225,7 @@ async def executemany(self, command: str, args, timeout: float=None):
208225
209226
.. versionadded:: 0.7.0
210227
"""
211-
stmt = await self._get_statement(command, timeout)
212-
return await self._protocol.bind_execute_many(stmt, args, '', timeout)
228+
return await self._executemany(command, args, timeout)
213229

214230
async def _get_statement(self, query, timeout):
215231
cache = self._stmt_cache_max_size > 0
@@ -281,7 +297,7 @@ async def fetch(self, query, *args, timeout=None) -> list:
281297
282298
:return list: A list of :class:`Record` instances.
283299
"""
284-
return await self._do_execute(query, args, 0, timeout)
300+
return await self._execute(query, args, 0, timeout)
285301

286302
async def fetchval(self, query, *args, column=0, timeout=None):
287303
"""Run a query and return a value in the first row.
@@ -297,7 +313,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):
297313
298314
:return: The value of the specified column of the first record.
299315
"""
300-
data = await self._do_execute(query, args, 1, timeout)
316+
data = await self._execute(query, args, 1, timeout)
301317
if not data:
302318
return None
303319
return data[0][column]
@@ -311,7 +327,7 @@ async def fetchrow(self, query, *args, timeout=None):
311327
312328
:return: The first row as a :class:`Record` instance.
313329
"""
314-
data = await self._do_execute(query, args, 1, timeout)
330+
data = await self._execute(query, args, 1, timeout)
315331
if not data:
316332
return None
317333
return data[0]
@@ -430,7 +446,9 @@ async def _cleanup_stmts(self):
430446
to_close = self._stmts_to_close
431447
self._stmts_to_close = set()
432448
for stmt in to_close:
433-
await self._protocol.close_statement(stmt, False)
449+
# It is imperative that statements are cleaned properly,
450+
# so we ignore the timeout.
451+
await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT)
434452

435453
def _request_portal_name(self):
436454
return self._get_unique_id()
@@ -554,13 +572,37 @@ def _drop_global_statement_cache(self):
554572
else:
555573
self._drop_local_statement_cache()
556574

557-
async def _do_execute(self, query, args, limit, timeout,
558-
return_status=False):
559-
stmt = await self._get_statement(query, timeout)
575+
def _execute(self, query, args, limit, timeout, return_status=False):
576+
executor = lambda stmt, timeout: self._protocol.bind_execute(
577+
stmt, args, '', limit, return_status, timeout)
578+
timeout = self._protocol._get_timeout(timeout)
579+
return self._do_execute(query, executor, timeout)
580+
581+
def _executemany(self, query, args, timeout):
582+
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
583+
stmt, args, '', timeout)
584+
timeout = self._protocol._get_timeout(timeout)
585+
return self._do_execute(query, executor, timeout)
586+
587+
async def _do_execute(self, query, executor, timeout, retry=True):
588+
if timeout is None:
589+
stmt = await self._get_statement(query, None)
590+
else:
591+
before = time.monotonic()
592+
stmt = await self._get_statement(query, timeout)
593+
after = time.monotonic()
594+
timeout -= after - before
595+
before = after
560596

561597
try:
562-
result = await self._protocol.bind_execute(
563-
stmt, args, '', limit, return_status, timeout)
598+
if timeout is None:
599+
result = await executor(stmt, None)
600+
else:
601+
try:
602+
result = await executor(stmt, timeout)
603+
finally:
604+
after = time.monotonic()
605+
timeout -= after - before
564606

565607
except exceptions.InvalidCachedStatementError as e:
566608
# PostgreSQL will raise an exception when it detects
@@ -586,13 +628,11 @@ async def _do_execute(self, query, args, limit, timeout,
586628
# for discussion.
587629
#
588630
self._drop_global_statement_cache()
589-
590-
if self._protocol.is_in_transaction():
631+
if self._protocol.is_in_transaction() or not retry:
591632
raise
592633
else:
593-
stmt = await self._get_statement(query, timeout)
594-
result = await self._protocol.bind_execute(
595-
stmt, args, '', limit, return_status, timeout)
634+
result = await self._do_execute(
635+
query, executor, timeout, retry=False)
596636

597637
return result
598638

asyncpg/protocol/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8-
from .protocol import Protocol, Record # NOQA
8+
from .protocol import Protocol, Record, NO_TIMEOUT # NOQA

asyncpg/protocol/protocol.pxd

+2-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ cdef class BaseProtocol(CoreProtocol):
4747

4848
PreparedStatementState statement
4949

50-
cdef _ensure_clear_state(self)
50+
cdef _get_timeout_impl(self, timeout)
51+
cdef _check_state(self)
5152
cdef _new_waiter(self, timeout)
5253

5354
cdef _on_result__connect(self, object waiter)

asyncpg/protocol/protocol.pyx

+49-13
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ include "coreproto.pyx"
7979
include "prepared_stmt.pyx"
8080

8181

82+
NO_TIMEOUT = object()
83+
84+
8285
cdef class BaseProtocol(CoreProtocol):
8386
def __init__(self, addr, connected_fut, con_args, loop):
8487
CoreProtocol.__init__(self, con_args)
@@ -132,7 +135,8 @@ cdef class BaseProtocol(CoreProtocol):
132135
await self.cancel_sent_waiter
133136
self.cancel_sent_waiter = None
134137

135-
self._ensure_clear_state()
138+
self._check_state()
139+
timeout = self._get_timeout_impl(timeout)
136140

137141
if stmt_name is None:
138142
self.uid_counter += 1
@@ -154,7 +158,8 @@ cdef class BaseProtocol(CoreProtocol):
154158
await self.cancel_sent_waiter
155159
self.cancel_sent_waiter = None
156160

157-
self._ensure_clear_state()
161+
self._check_state()
162+
timeout = self._get_timeout_impl(timeout)
158163

159164
self._bind_execute(
160165
portal_name,
@@ -178,7 +183,8 @@ cdef class BaseProtocol(CoreProtocol):
178183
await self.cancel_sent_waiter
179184
self.cancel_sent_waiter = None
180185

181-
self._ensure_clear_state()
186+
self._check_state()
187+
timeout = self._get_timeout_impl(timeout)
182188

183189
# Make sure the argument sequence is encoded lazily with
184190
# this generator expression to keep the memory pressure under
@@ -209,7 +215,8 @@ cdef class BaseProtocol(CoreProtocol):
209215
await self.cancel_sent_waiter
210216
self.cancel_sent_waiter = None
211217

212-
self._ensure_clear_state()
218+
self._check_state()
219+
timeout = self._get_timeout_impl(timeout)
213220

214221
self._bind(
215222
portal_name,
@@ -231,7 +238,8 @@ cdef class BaseProtocol(CoreProtocol):
231238
await self.cancel_sent_waiter
232239
self.cancel_sent_waiter = None
233240

234-
self._ensure_clear_state()
241+
self._check_state()
242+
timeout = self._get_timeout_impl(timeout)
235243

236244
self._execute(
237245
portal_name,
@@ -251,7 +259,11 @@ cdef class BaseProtocol(CoreProtocol):
251259
await self.cancel_sent_waiter
252260
self.cancel_sent_waiter = None
253261

254-
self._ensure_clear_state()
262+
self._check_state()
263+
# query() needs to call _get_timeout instead of _get_timeout_impl
264+
# for consistent validation, as it is called differently from
265+
# prepare/bind/execute methods.
266+
timeout = self._get_timeout(timeout)
255267

256268
self._simple_query(query)
257269
self.last_query = query
@@ -266,7 +278,8 @@ cdef class BaseProtocol(CoreProtocol):
266278
await self.cancel_sent_waiter
267279
self.cancel_sent_waiter = None
268280

269-
self._ensure_clear_state()
281+
self._check_state()
282+
timeout = self._get_timeout_impl(timeout)
270283

271284
if state.refs != 0:
272285
raise RuntimeError(
@@ -348,7 +361,32 @@ cdef class BaseProtocol(CoreProtocol):
348361
cdef _set_server_parameter(self, name, val):
349362
self.settings.add_setting(name, val)
350363

351-
cdef _ensure_clear_state(self):
364+
def _get_timeout(self, timeout):
365+
if timeout is not None:
366+
try:
367+
if type(timeout) is bool:
368+
raise ValueError
369+
timeout = float(timeout)
370+
except ValueError:
371+
raise ValueError(
372+
'invalid timeout value: expected non-negative float '
373+
'(got {!r})'.format(timeout)) from None
374+
375+
return self._get_timeout_impl(timeout)
376+
377+
cdef inline _get_timeout_impl(self, timeout):
378+
if timeout is None:
379+
timeout = self.connection._command_timeout
380+
elif timeout is NO_TIMEOUT:
381+
timeout = None
382+
else:
383+
timeout = float(timeout)
384+
385+
if timeout is not None and timeout <= 0:
386+
raise asyncio.TimeoutError()
387+
return timeout
388+
389+
cdef _check_state(self):
352390
if self.cancel_waiter is not None:
353391
raise apg_exc.InterfaceError(
354392
'cannot perform operation: another operation is cancelling')
@@ -361,11 +399,9 @@ cdef class BaseProtocol(CoreProtocol):
361399

362400
cdef _new_waiter(self, timeout):
363401
self.waiter = self.create_future()
364-
if timeout is not False:
365-
timeout = timeout or self.connection._command_timeout
366-
if timeout is not None and timeout > 0:
367-
self.timeout_handle = self.connection._loop.call_later(
368-
timeout, self.timeout_callback, self.waiter)
402+
if timeout is not None:
403+
self.timeout_handle = self.connection._loop.call_later(
404+
timeout, self.timeout_callback, self.waiter)
369405
self.waiter.add_done_callback(self.completed_callback)
370406
return self.waiter
371407

tests/test_timeout.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77

88
import asyncio
9-
import asyncpg
109

10+
import asyncpg
11+
from asyncpg import connection as pg_connection
1112
from asyncpg import _testbase as tb
1213

1314

@@ -108,12 +109,28 @@ async def test_timeout_06(self):
108109

109110
self.assertEqual(await self.con.fetch('select 1'), [(1,)])
110111

112+
async def test_invalid_timeout(self):
113+
for command_timeout in ('a', False, -1):
114+
with self.subTest(command_timeout=command_timeout):
115+
with self.assertRaisesRegex(ValueError,
116+
'invalid command_timeout'):
117+
await self.cluster.connect(
118+
database='postgres', loop=self.loop,
119+
command_timeout=command_timeout)
120+
121+
# Note: negative timeouts are OK for method calls.
122+
for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
123+
for timeout in ('a', False):
124+
with self.subTest(timeout=timeout):
125+
with self.assertRaisesRegex(ValueError, 'invalid timeout'):
126+
await self.con.execute('SELECT 1', timeout=timeout)
127+
111128

112129
class TestConnectionCommandTimeout(tb.ConnectedTestCase):
113130

114131
def getExtraConnectOptions(self):
115132
return {
116-
'command_timeout': 0.02
133+
'command_timeout': 0.2
117134
}
118135

119136
async def test_command_timeout_01(self):
@@ -123,3 +140,25 @@ async def test_command_timeout_01(self):
123140
meth = getattr(self.con, methname)
124141
await meth('select pg_sleep(10)')
125142
self.assertEqual(await self.con.fetch('select 1'), [(1,)])
143+
144+
145+
class SlowPrepareConnection(pg_connection.Connection):
146+
"""Connection class to test timeouts."""
147+
async def _get_statement(self, query, timeout):
148+
await asyncio.sleep(0.15, loop=self._loop)
149+
return await super()._get_statement(query, timeout)
150+
151+
152+
class TestTimeoutCoversPrepare(tb.ConnectedTestCase):
153+
154+
def getExtraConnectOptions(self):
155+
return {
156+
'__connection_class__': SlowPrepareConnection,
157+
'command_timeout': 0.3
158+
}
159+
160+
async def test_timeout_covers_prepare_01(self):
161+
for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
162+
with self.assertRaises(asyncio.TimeoutError):
163+
meth = getattr(self.con, methname)
164+
await meth('select pg_sleep($1)', 0.2)

0 commit comments

Comments
 (0)