Skip to content

Commit 10955fd

Browse files
committed
Add new max_cached_statement_lifetime parameter to Connection.
The parameter allows asyncpg to refresh cached prepared statements periodically. See also issue #76.
1 parent bb98b79 commit 10955fd

File tree

4 files changed

+302
-73
lines changed

4 files changed

+302
-73
lines changed

asyncpg/_testbase.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,32 @@ def start_cluster(cls, ClusterCls, *,
190190
return _start_cluster(ClusterCls, cluster_kwargs, server_settings)
191191

192192

193+
def with_connection_options(**options):
    """Return a decorator that records *options* on a test method.

    The decorated test function gets a ``__connect_options__`` attribute;
    ``ConnectedTestCase.setUp()`` reads it back and forwards the options
    to ``cluster.connect()``.

    :raises ValueError: if called with no options at all.
    """
    if not options:
        raise ValueError('no connection options were specified')

    def decorator(test_func):
        test_func.__connect_options__ = options
        return test_func

    return decorator
202+
203+
193204
class ConnectedTestCase(ClusterTestCase):
194205

195206
def getExtraConnectOptions(self):
196207
return {}
197208

198209
def setUp(self):
199210
super().setUp()
200-
opts = self.getExtraConnectOptions()
211+
212+
# Extract options set up with `with_connection_options`.
213+
test_func = getattr(self, self._testMethodName).__func__
214+
opts = getattr(test_func, '__connect_options__', {})
215+
201216
self.con = self.loop.run_until_complete(
202217
self.cluster.connect(database='postgres', loop=self.loop, **opts))
218+
203219
self.server_version = self.con.get_server_version()
204220

205221
def tearDown(self):

asyncpg/connection.py

+198-34
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ class Connection(metaclass=ConnectionMeta):
3939

4040
__slots__ = ('_protocol', '_transport', '_loop', '_types_stmt',
4141
'_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
42-
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
42+
'_stmt_cache', '_stmts_to_close',
4343
'_addr', '_opts', '_command_timeout', '_listeners',
4444
'_server_version', '_server_caps', '_intro_query',
4545
'_reset_query', '_proxy', '_stmt_exclusive_section')
4646

4747
def __init__(self, protocol, transport, loop, addr, opts, *,
48-
statement_cache_size, command_timeout):
48+
statement_cache_size, command_timeout,
49+
max_cached_statement_lifetime):
4950
self._protocol = protocol
5051
self._transport = transport
5152
self._loop = loop
@@ -58,8 +59,12 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
5859
self._addr = addr
5960
self._opts = opts
6061

61-
self._stmt_cache_max_size = statement_cache_size
62-
self._stmt_cache = collections.OrderedDict()
62+
self._stmt_cache = _StatementCache(
63+
loop=loop,
64+
max_size=statement_cache_size,
65+
on_remove=self._maybe_gc_stmt,
66+
max_lifetime=max_cached_statement_lifetime)
67+
6368
self._stmts_to_close = set()
6469

6570
if command_timeout is not None:
@@ -126,6 +131,8 @@ async def add_listener(self, channel, callback):
126131

127132
async def remove_listener(self, channel, callback):
128133
"""Remove a listening callback on the specified channel."""
134+
if self.is_closed():
135+
return
129136
if channel not in self._listeners:
130137
return
131138
if callback not in self._listeners[channel]:
@@ -266,46 +273,33 @@ async def executemany(self, command: str, args,
266273
return await self._executemany(command, args, timeout)
267274

268275
async def _get_statement(self, query, timeout, *, named: bool=False):
269-
use_cache = self._stmt_cache_max_size > 0
270-
if use_cache:
271-
try:
272-
state = self._stmt_cache[query]
273-
except KeyError:
274-
pass
275-
else:
276-
self._stmt_cache.move_to_end(query, last=True)
277-
if not state.closed:
278-
return state
279-
280-
protocol = self._protocol
276+
statement = self._stmt_cache.get(query)
277+
if statement is not None:
278+
return statement
281279

282-
if use_cache or named:
280+
if self._stmt_cache.get_max_size() or named:
283281
stmt_name = self._get_unique_id('stmt')
284282
else:
285283
stmt_name = ''
286284

287-
state = await protocol.prepare(stmt_name, query, timeout)
285+
statement = await self._protocol.prepare(stmt_name, query, timeout)
288286

289-
ready = state._init_types()
287+
ready = statement._init_types()
290288
if ready is not True:
291289
if self._types_stmt is None:
292290
self._types_stmt = await self.prepare(self._intro_query)
293291

294292
types = await self._types_stmt.fetch(list(ready))
295-
protocol.get_settings().register_data_types(types)
293+
self._protocol.get_settings().register_data_types(types)
296294

297-
if use_cache:
298-
if len(self._stmt_cache) > self._stmt_cache_max_size - 1:
299-
old_query, old_state = self._stmt_cache.popitem(last=False)
300-
self._maybe_gc_stmt(old_state)
301-
self._stmt_cache[query] = state
295+
self._stmt_cache.put(query, statement)
302296

303297
# If we've just created a new statement object, check if there
304298
# are any statements for GC.
305299
if self._stmts_to_close:
306300
await self._cleanup_stmts()
307301

308-
return state
302+
return statement
309303

310304
def cursor(self, query, *args, prefetch=None, timeout=None):
311305
"""Return a *cursor factory* for the specified query.
@@ -457,14 +451,14 @@ async def close(self):
457451
"""Close the connection gracefully."""
458452
if self.is_closed():
459453
return
460-
self._close_stmts()
454+
self._mark_stmts_as_closed()
461455
self._listeners = {}
462456
self._aborted = True
463457
await self._protocol.close()
464458

465459
def terminate(self):
466460
"""Terminate the connection without waiting for pending data."""
467-
self._close_stmts()
461+
self._mark_stmts_as_closed()
468462
self._listeners = {}
469463
self._aborted = True
470464
self._protocol.abort()
@@ -484,8 +478,8 @@ def _get_unique_id(self, prefix):
484478
self._uid += 1
485479
return '__asyncpg_{}_{}__'.format(prefix, self._uid)
486480

487-
def _close_stmts(self):
488-
for stmt in self._stmt_cache.values():
481+
def _mark_stmts_as_closed(self):
482+
for stmt in self._stmt_cache.iter_statements():
489483
stmt.mark_closed()
490484

491485
for stmt in self._stmts_to_close:
@@ -495,11 +489,22 @@ def _close_stmts(self):
495489
self._stmts_to_close.clear()
496490

497491
def _maybe_gc_stmt(self, stmt):
498-
if stmt.refs == 0 and stmt.query not in self._stmt_cache:
492+
if stmt.refs == 0 and not self._stmt_cache.has(stmt.query):
493+
# If low-level `stmt` isn't referenced from any high-level
494+
# `PreparedStatement` object and is not in the `_stmt_cache`:
495+
#
496+
# * mark it as closed, which will make it non-usable
497+
# for any `PreparedStatement` or for methods like
498+
# `Connection.fetch()`.
499+
#
500+
# * schedule it to be formally closed on the server.
499501
stmt.mark_closed()
500502
self._stmts_to_close.add(stmt)
501503

502504
async def _cleanup_stmts(self):
505+
# Called whenever we create a new prepared statement in
506+
# `Connection._get_statement()` and `_stmts_to_close` is
507+
# not empty.
503508
to_close = self._stmts_to_close
504509
self._stmts_to_close = set()
505510
for stmt in to_close:
@@ -700,6 +705,7 @@ async def connect(dsn=None, *,
700705
loop=None,
701706
timeout=60,
702707
statement_cache_size=100,
708+
max_cached_statement_lifetime=300,
703709
command_timeout=None,
704710
__connection_class__=Connection,
705711
**opts):
@@ -735,6 +741,12 @@ async def connect(dsn=None, *,
735741
:param float timeout: connection timeout in seconds.
736742
737743
:param int statement_cache_size: the size of prepared statement LRU cache.
744+
Pass ``0`` to disable the cache.
745+
746+
:param int max_cached_statement_lifetime:
747+
the maximum time in seconds a prepared statement will stay
748+
in the cache. Pass ``0`` to allow statements to be cached
749+
indefinitely.
738750
739751
:param float command_timeout: the default timeout for operations on
740752
this connection (the default is no timeout).
@@ -753,6 +765,9 @@ async def connect(dsn=None, *,
753765
... print(types)
754766
>>> asyncio.get_event_loop().run_until_complete(run())
755767
[<Record typname='bool' typnamespace=11 ...
768+
769+
.. versionchanged:: 0.10.0
770+
Added ``max_cached_statement_lifetime`` parameter.
756771
"""
757772
if loop is None:
758773
loop = asyncio.get_event_loop()
@@ -796,13 +811,162 @@ async def connect(dsn=None, *,
796811
tr.close()
797812
raise
798813

799-
con = __connection_class__(pr, tr, loop, addr, opts,
800-
statement_cache_size=statement_cache_size,
801-
command_timeout=command_timeout)
814+
con = __connection_class__(
815+
pr, tr, loop, addr, opts,
816+
statement_cache_size=statement_cache_size,
817+
max_cached_statement_lifetime=max_cached_statement_lifetime,
818+
command_timeout=command_timeout)
819+
802820
pr.set_connection(con)
803821
return con
804822

805823

824+
class _StatementCacheEntry:
    """One cache slot: a query, its prepared statement, and the expiration
    timer handle (if entry lifetimes are enabled on the owning cache)."""

    __slots__ = ('_query', '_statement', '_cache', '_cleanup_cb')

    def __init__(self, cache, query, statement):
        self._cache = cache
        self._query = query
        self._statement = statement
        # asyncio.TimerHandle scheduled by
        # _StatementCache._set_entry_timeout(), or None when no
        # expiration is configured.
        self._cleanup_cb = None


class _StatementCache:
    """LRU cache of prepared statements with optional per-entry lifetime.

    :param loop: event loop used to schedule entry expiration via
                 ``loop.call_later()``.
    :param max_size: maximum number of cached statements; ``0`` disables
                     the cache entirely.
    :param on_remove: callback invoked with the statement whenever an
                      entry is evicted or expires.
    :param max_lifetime: seconds an entry may stay cached; ``0`` means
                         entries never expire.
    """

    __slots__ = ('_loop', '_entries', '_max_size', '_on_remove',
                 '_max_lifetime')

    def __init__(self, *, loop, max_size, on_remove, max_lifetime):
        self._loop = loop
        self._max_size = max_size
        self._on_remove = on_remove
        self._max_lifetime = max_lifetime

        # We use an OrderedDict for LRU implementation.  Operations:
        #
        # * We use a simple `__setitem__` to push a new entry:
        #       `entries[key] = new_entry`
        #   That will push `new_entry` to the *end* of the entries dict.
        #
        # * When we have a cache hit, we call
        #       `entries.move_to_end(key, last=True)`
        #   to move the entry to the *end* of the entries dict.
        #
        # * When we need to remove entries to maintain `max_size`, we call
        #       `entries.popitem(last=False)`
        #   to remove an entry from the *beginning* of the entries dict.
        #
        # So new entries and hits are always promoted to the end of the
        # entries dict, whereas the unused one will group in the
        # beginning of it.
        self._entries = collections.OrderedDict()

    def __len__(self):
        return len(self._entries)

    def get_max_size(self):
        return self._max_size

    def set_max_size(self, new_size):
        assert new_size >= 0
        self._max_size = new_size
        # Shrinking the limit may require evicting the oldest entries.
        self._maybe_cleanup()

    def get_max_lifetime(self):
        return self._max_lifetime

    def set_max_lifetime(self, new_lifetime):
        assert new_lifetime >= 0
        self._max_lifetime = new_lifetime
        for entry in self._entries.values():
            # For every entry cancel the existing callback
            # and setup a new one if necessary.
            self._set_entry_timeout(entry)

    def get(self, query, *, promote=True):
        """Return the cached statement for *query*, or None on a miss.

        A hit promotes the entry to most-recently-used unless
        ``promote`` is False (used by ``has()``).
        """
        if not self._max_size:
            # The cache is disabled.
            return

        entry = self._entries.get(query)  # type: _StatementCacheEntry
        if entry is None:
            return

        if entry._statement.closed:
            # Happens in unittests when we call `stmt._state.mark_closed()`
            # manually.
            self._entries.pop(query)
            self._clear_entry_callback(entry)
            return

        if promote:
            # `promote` is `False` when `get()` is called by `has()`.
            self._entries.move_to_end(query, last=True)

        return entry._statement

    def has(self, query):
        """Return True if *query* has a live cache entry (no promotion)."""
        return self.get(query, promote=False) is not None

    def put(self, query, statement):
        """Cache *statement* under *query*, evicting LRU entries if full."""
        if not self._max_size:
            # The cache is disabled.
            return

        self._entries[query] = self._new_entry(query, statement)

        # Check if the cache is bigger than max_size and trim it
        # if necessary.
        self._maybe_cleanup()

    def iter_statements(self):
        """Iterate over all cached statements (no LRU promotion)."""
        return (e._statement for e in self._entries.values())

    def clear(self):
        """Drop all entries without notifying ``on_remove``."""
        # First, make sure that we cancel all scheduled callbacks.
        for entry in self._entries.values():
            self._clear_entry_callback(entry)

        # Clear the entries dict.
        self._entries.clear()

    def _set_entry_timeout(self, entry):
        # Clear the existing timeout.
        self._clear_entry_callback(entry)

        # Set the new timeout if it's not 0.
        if self._max_lifetime:
            entry._cleanup_cb = self._loop.call_later(
                self._max_lifetime, self._on_entry_expired, entry)

    def _new_entry(self, query, statement):
        entry = _StatementCacheEntry(self, query, statement)
        self._set_entry_timeout(entry)
        return entry

    def _on_entry_expired(self, entry):
        # `call_later` callback, called when an entry stayed longer
        # than `self._max_lifetime`.
        if self._entries.get(entry._query) is entry:
            self._entries.pop(entry._query)
            self._on_remove(entry._statement)

    def _clear_entry_callback(self, entry):
        if entry._cleanup_cb is not None:
            entry._cleanup_cb.cancel()
            # Drop the reference: a cancelled TimerHandle (and the
            # entry it closes over) should not be retained, and a
            # repeated clear must not cancel a dead handle again.
            entry._cleanup_cb = None

    def _maybe_cleanup(self):
        # Delete cache entries until the size of the cache is `max_size`.
        while len(self._entries) > self._max_size:
            old_query, old_entry = self._entries.popitem(last=False)
            self._clear_entry_callback(old_entry)

            # Let the connection know that the statement was removed
            # from the cache.
            self._on_remove(old_entry._statement)
968+
969+
806970
class _Atomic:
807971
__slots__ = ('_acquired',)
808972

0 commit comments

Comments
 (0)