@@ -39,13 +39,14 @@ class Connection(metaclass=ConnectionMeta):
39
39
40
40
__slots__ = ('_protocol' , '_transport' , '_loop' , '_types_stmt' ,
41
41
'_type_by_name_stmt' , '_top_xact' , '_uid' , '_aborted' ,
42
- '_stmt_cache_max_size' , ' _stmt_cache' , '_stmts_to_close' ,
42
+ '_stmt_cache' , '_stmts_to_close' ,
43
43
'_addr' , '_opts' , '_command_timeout' , '_listeners' ,
44
44
'_server_version' , '_server_caps' , '_intro_query' ,
45
45
'_reset_query' , '_proxy' , '_stmt_exclusive_section' )
46
46
47
47
def __init__ (self , protocol , transport , loop , addr , opts , * ,
48
- statement_cache_size , command_timeout ):
48
+ statement_cache_size , command_timeout ,
49
+ max_cached_statement_lifetime ):
49
50
self ._protocol = protocol
50
51
self ._transport = transport
51
52
self ._loop = loop
@@ -58,8 +59,12 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
58
59
self ._addr = addr
59
60
self ._opts = opts
60
61
61
- self ._stmt_cache_max_size = statement_cache_size
62
- self ._stmt_cache = collections .OrderedDict ()
62
+ self ._stmt_cache = _StatementCache (
63
+ loop = loop ,
64
+ max_size = statement_cache_size ,
65
+ on_remove = self ._maybe_gc_stmt ,
66
+ max_lifetime = max_cached_statement_lifetime )
67
+
63
68
self ._stmts_to_close = set ()
64
69
65
70
if command_timeout is not None :
@@ -126,6 +131,8 @@ async def add_listener(self, channel, callback):
126
131
127
132
async def remove_listener (self , channel , callback ):
128
133
"""Remove a listening callback on the specified channel."""
134
+ if self .is_closed ():
135
+ return
129
136
if channel not in self ._listeners :
130
137
return
131
138
if callback not in self ._listeners [channel ]:
@@ -266,46 +273,33 @@ async def executemany(self, command: str, args,
266
273
return await self ._executemany (command , args , timeout )
267
274
268
275
async def _get_statement (self , query , timeout , * , named : bool = False ):
269
- use_cache = self ._stmt_cache_max_size > 0
270
- if use_cache :
271
- try :
272
- state = self ._stmt_cache [query ]
273
- except KeyError :
274
- pass
275
- else :
276
- self ._stmt_cache .move_to_end (query , last = True )
277
- if not state .closed :
278
- return state
279
-
280
- protocol = self ._protocol
276
+ statement = self ._stmt_cache .get (query )
277
+ if statement is not None :
278
+ return statement
281
279
282
- if use_cache or named :
280
+ if self . _stmt_cache . get_max_size () or named :
283
281
stmt_name = self ._get_unique_id ('stmt' )
284
282
else :
285
283
stmt_name = ''
286
284
287
- state = await protocol .prepare (stmt_name , query , timeout )
285
+ statement = await self . _protocol .prepare (stmt_name , query , timeout )
288
286
289
- ready = state ._init_types ()
287
+ ready = statement ._init_types ()
290
288
if ready is not True :
291
289
if self ._types_stmt is None :
292
290
self ._types_stmt = await self .prepare (self ._intro_query )
293
291
294
292
types = await self ._types_stmt .fetch (list (ready ))
295
- protocol .get_settings ().register_data_types (types )
293
+ self . _protocol .get_settings ().register_data_types (types )
296
294
297
- if use_cache :
298
- if len (self ._stmt_cache ) > self ._stmt_cache_max_size - 1 :
299
- old_query , old_state = self ._stmt_cache .popitem (last = False )
300
- self ._maybe_gc_stmt (old_state )
301
- self ._stmt_cache [query ] = state
295
+ self ._stmt_cache .put (query , statement )
302
296
303
297
# If we've just created a new statement object, check if there
304
298
# are any statements for GC.
305
299
if self ._stmts_to_close :
306
300
await self ._cleanup_stmts ()
307
301
308
- return state
302
+ return statement
309
303
310
304
def cursor (self , query , * args , prefetch = None , timeout = None ):
311
305
"""Return a *cursor factory* for the specified query.
@@ -457,14 +451,14 @@ async def close(self):
457
451
"""Close the connection gracefully."""
458
452
if self .is_closed ():
459
453
return
460
- self ._close_stmts ()
454
+ self ._mark_stmts_as_closed ()
461
455
self ._listeners = {}
462
456
self ._aborted = True
463
457
await self ._protocol .close ()
464
458
465
459
def terminate (self ):
466
460
"""Terminate the connection without waiting for pending data."""
467
- self ._close_stmts ()
461
+ self ._mark_stmts_as_closed ()
468
462
self ._listeners = {}
469
463
self ._aborted = True
470
464
self ._protocol .abort ()
@@ -484,8 +478,8 @@ def _get_unique_id(self, prefix):
484
478
self ._uid += 1
485
479
return '__asyncpg_{}_{}__' .format (prefix , self ._uid )
486
480
487
- def _close_stmts (self ):
488
- for stmt in self ._stmt_cache .values ():
481
+ def _mark_stmts_as_closed (self ):
482
+ for stmt in self ._stmt_cache .iter_statements ():
489
483
stmt .mark_closed ()
490
484
491
485
for stmt in self ._stmts_to_close :
@@ -495,11 +489,22 @@ def _close_stmts(self):
495
489
self ._stmts_to_close .clear ()
496
490
497
491
def _maybe_gc_stmt (self , stmt ):
498
- if stmt .refs == 0 and stmt .query not in self ._stmt_cache :
492
+ if stmt .refs == 0 and not self ._stmt_cache .has (stmt .query ):
493
+ # If low-level `stmt` isn't referenced from any high-level
494
+ # `PreparedStatament` object and is not in the `_stmt_cache`:
495
+ #
496
+ # * mark it as closed, which will make it non-usable
497
+ # for any `PreparedStatament` or for methods like
498
+ # `Connection.fetch()`.
499
+ #
500
+ # * schedule it to be formally closed on the server.
499
501
stmt .mark_closed ()
500
502
self ._stmts_to_close .add (stmt )
501
503
502
504
async def _cleanup_stmts (self ):
505
+ # Called whenever we create a new prepared statement in
506
+ # `Connection._get_statement()` and `_stmts_to_close` is
507
+ # not empty.
503
508
to_close = self ._stmts_to_close
504
509
self ._stmts_to_close = set ()
505
510
for stmt in to_close :
@@ -700,6 +705,7 @@ async def connect(dsn=None, *,
700
705
loop = None ,
701
706
timeout = 60 ,
702
707
statement_cache_size = 100 ,
708
+ max_cached_statement_lifetime = 300 ,
703
709
command_timeout = None ,
704
710
__connection_class__ = Connection ,
705
711
** opts ):
@@ -735,6 +741,12 @@ async def connect(dsn=None, *,
735
741
:param float timeout: connection timeout in seconds.
736
742
737
743
:param int statement_cache_size: the size of prepared statement LRU cache.
744
+ Pass ``0`` to disable the cache.
745
+
746
+ :param int max_cached_statement_lifetime:
747
+ the maximum time in seconds a prepared statement will stay
748
+ in the cache. Pass ``0`` to allow statements be cached
749
+ indefinitely.
738
750
739
751
:param float command_timeout: the default timeout for operations on
740
752
this connection (the default is no timeout).
@@ -753,6 +765,9 @@ async def connect(dsn=None, *,
753
765
... print(types)
754
766
>>> asyncio.get_event_loop().run_until_complete(run())
755
767
[<Record typname='bool' typnamespace=11 ...
768
+
769
+ .. versionchanged:: 0.10.0
770
+ Added ``max_cached_statement_use_count`` parameter.
756
771
"""
757
772
if loop is None :
758
773
loop = asyncio .get_event_loop ()
@@ -796,13 +811,162 @@ async def connect(dsn=None, *,
796
811
tr .close ()
797
812
raise
798
813
799
- con = __connection_class__ (pr , tr , loop , addr , opts ,
800
- statement_cache_size = statement_cache_size ,
801
- command_timeout = command_timeout )
814
+ con = __connection_class__ (
815
+ pr , tr , loop , addr , opts ,
816
+ statement_cache_size = statement_cache_size ,
817
+ max_cached_statement_lifetime = max_cached_statement_lifetime ,
818
+ command_timeout = command_timeout )
819
+
802
820
pr .set_connection (con )
803
821
return con
804
822
805
823
824
+ class _StatementCacheEntry :
825
+
826
+ __slots__ = ('_query' , '_statement' , '_cache' , '_cleanup_cb' )
827
+
828
+ def __init__ (self , cache , query , statement ):
829
+ self ._cache = cache
830
+ self ._query = query
831
+ self ._statement = statement
832
+ self ._cleanup_cb = None
833
+
834
+
835
+ class _StatementCache :
836
+
837
+ __slots__ = ('_loop' , '_entries' , '_max_size' , '_on_remove' ,
838
+ '_max_lifetime' )
839
+
840
+ def __init__ (self , * , loop , max_size , on_remove , max_lifetime ):
841
+ self ._loop = loop
842
+ self ._max_size = max_size
843
+ self ._on_remove = on_remove
844
+ self ._max_lifetime = max_lifetime
845
+
846
+ # We use an OrderedDict for LRU implementation. Operations:
847
+ #
848
+ # * We use a simple `__setitem__` to push a new entry:
849
+ # `entries[key] = new_entry`
850
+ # That will push `new_entry` to the *end* of the entries dict.
851
+ #
852
+ # * When we have a cache hit, we call
853
+ # `entries.move_to_end(key, last=True)`
854
+ # to move the entry to the *end* of the entries dict.
855
+ #
856
+ # * When we need to remove entries to maintain `max_size`, we call
857
+ # `entries.popitem(last=False)`
858
+ # to remove an entry from the *beginning* of the entries dict.
859
+ #
860
+ # So new entries and hits are always promoted to the end of the
861
+ # entries dict, whereas the unused one will group in the
862
+ # beginning of it.
863
+ self ._entries = collections .OrderedDict ()
864
+
865
+ def __len__ (self ):
866
+ return len (self ._entries )
867
+
868
+ def get_max_size (self ):
869
+ return self ._max_size
870
+
871
+ def set_max_size (self , new_size ):
872
+ assert new_size >= 0
873
+ self ._max_size = new_size
874
+ self ._maybe_cleanup ()
875
+
876
+ def get_max_lifetime (self ):
877
+ return self ._max_lifetime
878
+
879
+ def set_max_lifetime (self , new_lifetime ):
880
+ assert new_lifetime >= 0
881
+ self ._max_lifetime = new_lifetime
882
+ for entry in self ._entries .values ():
883
+ # For every entry cancel the existing callback
884
+ # and setup a new one if necessary.
885
+ self ._set_entry_timeout (entry )
886
+
887
+ def get (self , query , * , promote = True ):
888
+ if not self ._max_size :
889
+ # The cache is disabled.
890
+ return
891
+
892
+ entry = self ._entries .get (query ) # type: _StatementCacheEntry
893
+ if entry is None :
894
+ return
895
+
896
+ if entry ._statement .closed :
897
+ # Happens in unittests when we call `stmt._state.mark_closed()`
898
+ # manually.
899
+ self ._entries .pop (query )
900
+ self ._clear_entry_callback (entry )
901
+ return
902
+
903
+ if promote :
904
+ # `promote` is `False` when `get()` is called by `has()`.
905
+ self ._entries .move_to_end (query , last = True )
906
+
907
+ return entry ._statement
908
+
909
+ def has (self , query ):
910
+ return self .get (query , promote = False ) is not None
911
+
912
+ def put (self , query , statement ):
913
+ if not self ._max_size :
914
+ # The cache is disabled.
915
+ return
916
+
917
+ self ._entries [query ] = self ._new_entry (query , statement )
918
+
919
+ # Check if the cache is bigger than max_size and trim it
920
+ # if necessary.
921
+ self ._maybe_cleanup ()
922
+
923
+ def iter_statements (self ):
924
+ return (e ._statement for e in self ._entries .values ())
925
+
926
+ def clear (self ):
927
+ # First, make sure that we cancel all scheduled callbacks.
928
+ for entry in self ._entries .values ():
929
+ self ._clear_entry_callback (entry )
930
+
931
+ # Clear the entries dict.
932
+ self ._entries .clear ()
933
+
934
+ def _set_entry_timeout (self , entry ):
935
+ # Clear the existing timeout.
936
+ self ._clear_entry_callback (entry )
937
+
938
+ # Set the new timeout if it's not 0.
939
+ if self ._max_lifetime :
940
+ entry ._cleanup_cb = self ._loop .call_later (
941
+ self ._max_lifetime , self ._on_entry_expired , entry )
942
+
943
+ def _new_entry (self , query , statement ):
944
+ entry = _StatementCacheEntry (self , query , statement )
945
+ self ._set_entry_timeout (entry )
946
+ return entry
947
+
948
+ def _on_entry_expired (self , entry ):
949
+ # `call_later` callback, called when an entry stayed longer
950
+ # than `self._max_lifetime`.
951
+ if self ._entries .get (entry ._query ) is entry :
952
+ self ._entries .pop (entry ._query )
953
+ self ._on_remove (entry ._statement )
954
+
955
+ def _clear_entry_callback (self , entry ):
956
+ if entry ._cleanup_cb is not None :
957
+ entry ._cleanup_cb .cancel ()
958
+
959
+ def _maybe_cleanup (self ):
960
+ # Delete cache entries until the size of the cache is `max_size`.
961
+ while len (self ._entries ) > self ._max_size :
962
+ old_query , old_entry = self ._entries .popitem (last = False )
963
+ self ._clear_entry_callback (old_entry )
964
+
965
+ # Let the connection know that the statement was removed
966
+ # from the cache.
967
+ self ._on_remove (old_entry ._statement )
968
+
969
+
806
970
class _Atomic :
807
971
__slots__ = ('_acquired' ,)
808
972
0 commit comments