Skip to content

Commit a3e9cb4

Browse files
committed
Use anonymous prepared statements when the cache is off.
`Connection.fetch()` and similar methods will now use anonymous prepared statements when the statements cache is off. This makes it possible to use a limited subset of asyncpg functionality with pgbouncer, and optimizes use of PostgreSQL server resources.
1 parent 4fdc1db commit a3e9cb4

File tree

6 files changed

+83
-28
lines changed

6 files changed

+83
-28
lines changed

asyncpg/connection.py

+46-17
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta):
4141
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
4242
'_addr', '_opts', '_command_timeout', '_listeners',
4343
'_server_version', '_server_caps', '_intro_query',
44-
'_reset_query', '_proxy')
44+
'_reset_query', '_proxy', '_stmt_exclusive_section')
4545

4646
def __init__(self, protocol, transport, loop, addr, opts, *,
4747
statement_cache_size, command_timeout):
@@ -97,6 +97,15 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
9797
self._reset_query = None
9898
self._proxy = None
9999

100+
# Used to serialize operations that might involve anonymous
101+
# statements. Specifically, we want to make the following
102+
# operation atomic:
103+
# ("prepare an anonymous statement", "use the statement")
104+
#
105+
# Used for `con.fetchval()`, `con.fetch()`, `con.fetchrow()`,
106+
# `con.execute()`, and `con.executemany()`.
107+
self._stmt_exclusive_section = _Atomic()
108+
100109
async def add_listener(self, channel, callback):
101110
"""Add a listener for Postgres notifications.
102111
@@ -227,10 +236,9 @@ async def executemany(self, command: str, args, timeout: float=None):
227236
"""
228237
return await self._executemany(command, args, timeout)
229238

230-
async def _get_statement(self, query, timeout):
231-
cache = self._stmt_cache_max_size > 0
232-
233-
if cache:
239+
async def _get_statement(self, query, timeout, *, named: bool=False):
240+
use_cache = self._stmt_cache_max_size > 0
241+
if use_cache:
234242
try:
235243
state = self._stmt_cache[query]
236244
except KeyError:
@@ -241,7 +249,13 @@ async def _get_statement(self, query, timeout):
241249
return state
242250

243251
protocol = self._protocol
244-
state = await protocol.prepare(None, query, timeout)
252+
253+
if use_cache or named:
254+
stmt_name = self._get_unique_id('stmt')
255+
else:
256+
stmt_name = ''
257+
258+
state = await protocol.prepare(stmt_name, query, timeout)
245259

246260
ready = state._init_types()
247261
if ready is not True:
@@ -251,7 +265,7 @@ async def _get_statement(self, query, timeout):
251265
types = await self._types_stmt.fetch(list(ready))
252266
protocol.get_settings().register_data_types(types)
253267

254-
if cache:
268+
if use_cache:
255269
if len(self._stmt_cache) > self._stmt_cache_max_size - 1:
256270
old_query, old_state = self._stmt_cache.popitem(last=False)
257271
self._maybe_gc_stmt(old_state)
@@ -285,7 +299,7 @@ async def prepare(self, query, *, timeout=None):
285299
286300
:return: A :class:`~prepared_stmt.PreparedStatement` instance.
287301
"""
288-
stmt = await self._get_statement(query, timeout)
302+
stmt = await self._get_statement(query, timeout, named=True)
289303
return prepared_stmt.PreparedStatement(self, query, stmt)
290304

291305
async def fetch(self, query, *args, timeout=None) -> list:
@@ -423,9 +437,9 @@ async def reset(self):
423437
if reset_query:
424438
await self.execute(reset_query)
425439

426-
def _get_unique_id(self):
440+
def _get_unique_id(self, prefix):
427441
self._uid += 1
428-
return 'id{}'.format(self._uid)
442+
return '__asyncpg_{}_{}__'.format(prefix, self._uid)
429443

430444
def _close_stmts(self):
431445
for stmt in self._stmt_cache.values():
@@ -450,9 +464,6 @@ async def _cleanup_stmts(self):
450464
# so we ignore the timeout.
451465
await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT)
452466

453-
def _request_portal_name(self):
454-
return self._get_unique_id()
455-
456467
def _cancel_current_command(self, waiter):
457468
async def cancel():
458469
try:
@@ -572,17 +583,19 @@ def _drop_global_statement_cache(self):
572583
else:
573584
self._drop_local_statement_cache()
574585

575-
def _execute(self, query, args, limit, timeout, return_status=False):
586+
async def _execute(self, query, args, limit, timeout, return_status=False):
576587
executor = lambda stmt, timeout: self._protocol.bind_execute(
577588
stmt, args, '', limit, return_status, timeout)
578589
timeout = self._protocol._get_timeout(timeout)
579-
return self._do_execute(query, executor, timeout)
590+
with self._stmt_exclusive_section:
591+
return await self._do_execute(query, executor, timeout)
580592

581-
def _executemany(self, query, args, timeout):
593+
async def _executemany(self, query, args, timeout):
582594
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
583595
stmt, args, '', timeout)
584596
timeout = self._protocol._get_timeout(timeout)
585-
return self._do_execute(query, executor, timeout)
597+
with self._stmt_exclusive_section:
598+
return await self._do_execute(query, executor, timeout)
586599

587600
async def _do_execute(self, query, executor, timeout, retry=True):
588601
if timeout is None:
@@ -747,6 +760,22 @@ async def connect(dsn=None, *,
747760
return con
748761

749762

763+
class _Atomic:
764+
__slots__ = ('_acquired',)
765+
766+
def __init__(self):
767+
self._acquired = 0
768+
769+
def __enter__(self):
770+
if self._acquired:
771+
raise exceptions.InterfaceError(
772+
'cannot perform operation: another operation is in progress')
773+
self._acquired = 1
774+
775+
def __exit__(self, t, e, tb):
776+
self._acquired = 0
777+
778+
750779
class _ConnectionProxy:
751780
# Base class to enable `isinstance(Connection)` check.
752781
__slots__ = ()

asyncpg/cursor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def _bind_exec(self, n, timeout):
9191
con = self._connection
9292
protocol = con._protocol
9393

94-
self._portal_name = con._request_portal_name()
94+
self._portal_name = con._get_unique_id('portal')
9595
buffer, _, self._exhausted = await protocol.bind_execute(
9696
self._state, self._args, self._portal_name, n, True, timeout)
9797
return buffer
@@ -106,7 +106,7 @@ async def _bind(self, timeout):
106106
con = self._connection
107107
protocol = con._protocol
108108

109-
self._portal_name = con._request_portal_name()
109+
self._portal_name = con._get_unique_id('portal')
110110
buffer = await protocol.bind(self._state, self._args,
111111
self._portal_name,
112112
timeout)
@@ -168,7 +168,7 @@ def __aiter__(self):
168168
async def __anext__(self):
169169
if self._state is None:
170170
self._state = await self._connection._get_statement(
171-
self._query, self._timeout)
171+
self._query, self._timeout, named=True)
172172
self._state.attach()
173173

174174
if not self._portal_name:
@@ -193,7 +193,7 @@ class Cursor(BaseCursor):
193193
async def _init(self, timeout):
194194
if self._state is None:
195195
self._state = await self._connection._get_statement(
196-
self._query, timeout)
196+
self._query, timeout, named=True)
197197
self._state.attach()
198198
self._check_ready()
199199
await self._bind(timeout)

asyncpg/protocol/protocol.pxd

-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ cdef class BaseProtocol(CoreProtocol):
4040

4141
str last_query
4242

43-
int uid_counter
4443
bint closing
4544

4645
readonly uint64_t queries_count

asyncpg/protocol/protocol.pyx

-5
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ cdef class BaseProtocol(CoreProtocol):
9595
self.settings = ConnectionSettings(
9696
(self.address, con_args.get('database')))
9797

98-
self.uid_counter = 0
9998
self.statement = None
10099
self.return_extra = False
101100

@@ -138,10 +137,6 @@ cdef class BaseProtocol(CoreProtocol):
138137
self._check_state()
139138
timeout = self._get_timeout_impl(timeout)
140139

141-
if stmt_name is None:
142-
self.uid_counter += 1
143-
stmt_name = '__asyncpg_stmt_{}__'.format(self.uid_counter)
144-
145140
self._prepare(stmt_name, query)
146141
self.last_query = query
147142
self.statement = PreparedStatementState(stmt_name, query, self)

asyncpg/transaction.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ async def start(self):
100100
self._nested = True
101101

102102
if self._nested:
103-
self._id = con._get_unique_id()
103+
self._id = con._get_unique_id('savepoint')
104104
query = 'SAVEPOINT {};'.format(self._id)
105105
else:
106106
if self._isolation == 'read_committed':

tests/test_prepare.py

+32
Original file line numberDiff line numberDiff line change
@@ -424,3 +424,35 @@ async def test_prepare_statement_invalid(self):
424424

425425
finally:
426426
await self.con.execute('DROP TABLE tab1')
427+
428+
async def test_prepare_23_no_stmt_cache_seq(self):
429+
# Disable cache, which will force connections to use
430+
# anonymous prepared statements.
431+
self.con._stmt_cache_max_size = 0
432+
433+
async def check_simple():
434+
# Run a simple query a few times.
435+
self.assertEqual(await self.con.fetchval('SELECT 1'), 1)
436+
self.assertEqual(await self.con.fetchval('SELECT 2'), 2)
437+
self.assertEqual(await self.con.fetchval('SELECT 1'), 1)
438+
439+
await check_simple()
440+
441+
# Run a query that timeouts.
442+
with self.assertRaises(asyncio.TimeoutError):
443+
await self.con.fetchrow('select pg_sleep(10)', timeout=0.02)
444+
445+
# Check that we can run new queries after a timeout.
446+
await check_simple()
447+
448+
# Try a cursor/timeout combination. Cursors should always use
449+
# named prepared statements.
450+
async with self.con.transaction():
451+
with self.assertRaises(asyncio.TimeoutError):
452+
async for _ in self.con.cursor( # NOQA
453+
'select pg_sleep(10)', timeout=0.1):
454+
pass
455+
456+
# Check that we can run queries after a failed cursor
457+
# operation.
458+
await check_simple()

0 commit comments

Comments
 (0)