@@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta):
41
41
'_stmt_cache_max_size' , '_stmt_cache' , '_stmts_to_close' ,
42
42
'_addr' , '_opts' , '_command_timeout' , '_listeners' ,
43
43
'_server_version' , '_server_caps' , '_intro_query' ,
44
- '_reset_query' , '_proxy' )
44
+ '_reset_query' , '_proxy' , '_stmt_exclusive_section' )
45
45
46
46
def __init__ (self , protocol , transport , loop , addr , opts , * ,
47
47
statement_cache_size , command_timeout ):
@@ -97,6 +97,15 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
97
97
self ._reset_query = None
98
98
self ._proxy = None
99
99
100
+ # Used to serialize operations that might involve anonymous
101
+ # statements. Specifically, we want to make the following
102
+ # operation atomic:
103
+ # ("prepare an anonymous statement", "use the statement")
104
+ #
105
+ # Used for `con.fetchval()`, `con.fetch()`, `con.fetchrow()`,
106
+ # `con.execute()`, and `con.executemany()`.
107
+ self ._stmt_exclusive_section = _Atomic ()
108
+
100
109
async def add_listener (self , channel , callback ):
101
110
"""Add a listener for Postgres notifications.
102
111
@@ -227,10 +236,9 @@ async def executemany(self, command: str, args, timeout: float=None):
227
236
"""
228
237
return await self ._executemany (command , args , timeout )
229
238
230
- async def _get_statement (self , query , timeout ):
231
- cache = self ._stmt_cache_max_size > 0
232
-
233
- if cache :
239
+ async def _get_statement (self , query , timeout , * , named : bool = False ):
240
+ use_cache = self ._stmt_cache_max_size > 0
241
+ if use_cache :
234
242
try :
235
243
state = self ._stmt_cache [query ]
236
244
except KeyError :
@@ -241,7 +249,13 @@ async def _get_statement(self, query, timeout):
241
249
return state
242
250
243
251
protocol = self ._protocol
244
- state = await protocol .prepare (None , query , timeout )
252
+
253
+ if use_cache or named :
254
+ stmt_name = self ._get_unique_id ('stmt' )
255
+ else :
256
+ stmt_name = ''
257
+
258
+ state = await protocol .prepare (stmt_name , query , timeout )
245
259
246
260
ready = state ._init_types ()
247
261
if ready is not True :
@@ -251,7 +265,7 @@ async def _get_statement(self, query, timeout):
251
265
types = await self ._types_stmt .fetch (list (ready ))
252
266
protocol .get_settings ().register_data_types (types )
253
267
254
- if cache :
268
+ if use_cache :
255
269
if len (self ._stmt_cache ) > self ._stmt_cache_max_size - 1 :
256
270
old_query , old_state = self ._stmt_cache .popitem (last = False )
257
271
self ._maybe_gc_stmt (old_state )
@@ -285,7 +299,7 @@ async def prepare(self, query, *, timeout=None):
285
299
286
300
:return: A :class:`~prepared_stmt.PreparedStatement` instance.
287
301
"""
288
- stmt = await self ._get_statement (query , timeout )
302
+ stmt = await self ._get_statement (query , timeout , named = True )
289
303
return prepared_stmt .PreparedStatement (self , query , stmt )
290
304
291
305
async def fetch (self , query , * args , timeout = None ) -> list :
@@ -423,9 +437,9 @@ async def reset(self):
423
437
if reset_query :
424
438
await self .execute (reset_query )
425
439
426
- def _get_unique_id (self ):
440
+ def _get_unique_id (self , prefix ):
427
441
self ._uid += 1
428
- return 'id{} ' .format (self ._uid )
442
+ return '__asyncpg_{}_{}__ ' .format (prefix , self ._uid )
429
443
430
444
def _close_stmts (self ):
431
445
for stmt in self ._stmt_cache .values ():
@@ -450,9 +464,6 @@ async def _cleanup_stmts(self):
450
464
# so we ignore the timeout.
451
465
await self ._protocol .close_statement (stmt , protocol .NO_TIMEOUT )
452
466
453
- def _request_portal_name (self ):
454
- return self ._get_unique_id ()
455
-
456
467
def _cancel_current_command (self , waiter ):
457
468
async def cancel ():
458
469
try :
@@ -572,17 +583,19 @@ def _drop_global_statement_cache(self):
572
583
else :
573
584
self ._drop_local_statement_cache ()
574
585
575
- def _execute (self , query , args , limit , timeout , return_status = False ):
586
+ async def _execute (self , query , args , limit , timeout , return_status = False ):
576
587
executor = lambda stmt , timeout : self ._protocol .bind_execute (
577
588
stmt , args , '' , limit , return_status , timeout )
578
589
timeout = self ._protocol ._get_timeout (timeout )
579
- return self ._do_execute (query , executor , timeout )
590
+ with self ._stmt_exclusive_section :
591
+ return await self ._do_execute (query , executor , timeout )
580
592
581
- def _executemany (self , query , args , timeout ):
593
+ async def _executemany (self , query , args , timeout ):
582
594
executor = lambda stmt , timeout : self ._protocol .bind_execute_many (
583
595
stmt , args , '' , timeout )
584
596
timeout = self ._protocol ._get_timeout (timeout )
585
- return self ._do_execute (query , executor , timeout )
597
+ with self ._stmt_exclusive_section :
598
+ return await self ._do_execute (query , executor , timeout )
586
599
587
600
async def _do_execute (self , query , executor , timeout , retry = True ):
588
601
if timeout is None :
@@ -747,6 +760,22 @@ async def connect(dsn=None, *,
747
760
return con
748
761
749
762
763
+ class _Atomic :
764
+ __slots__ = ('_acquired' ,)
765
+
766
+ def __init__ (self ):
767
+ self ._acquired = 0
768
+
769
+ def __enter__ (self ):
770
+ if self ._acquired :
771
+ raise exceptions .InterfaceError (
772
+ 'cannot perform operation: another operation is in progress' )
773
+ self ._acquired = 1
774
+
775
+ def __exit__ (self , t , e , tb ):
776
+ self ._acquired = 0
777
+
778
+
750
779
class _ConnectionProxy :
751
780
# Base class to enable `isinstance(Connection)` check.
752
781
__slots__ = ()
0 commit comments