Skip to content

Commit 537c8c9

Browse files
committed
Shield Pool.release() from task cancellation.
Use asyncio.shield() to guarantee that task cancellation does not prevent the connection from being returned to the pool properly. Fixes: #97.
1 parent d42608f commit 537c8c9

File tree

4 files changed

+83
-14
lines changed

4 files changed

+83
-14
lines changed

asyncpg/_testbase.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,22 @@ def _shutdown_cluster(cluster):
128128
cluster.destroy()
129129

130130

131+
def create_pool(dsn=None, *,
132+
min_size=10,
133+
max_size=10,
134+
max_queries=50000,
135+
setup=None,
136+
init=None,
137+
loop=None,
138+
pool_class=pg_pool.Pool,
139+
**connect_kwargs):
140+
return pool_class(
141+
dsn,
142+
min_size=min_size, max_size=max_size,
143+
max_queries=max_queries, loop=loop, setup=setup, init=init,
144+
**connect_kwargs)
145+
146+
131147
class ClusterTestCase(TestCase):
132148
@classmethod
133149
def setUpClass(cls):
@@ -136,10 +152,10 @@ def setUpClass(cls):
136152
'log_connections': 'on'
137153
})
138154

139-
def create_pool(self, **kwargs):
155+
def create_pool(self, pool_class=pg_pool.Pool, **kwargs):
140156
conn_spec = self.cluster.get_connection_spec()
141157
conn_spec.update(kwargs)
142-
return pg_pool.create_pool(loop=self.loop, **conn_spec)
158+
return create_pool(loop=self.loop, pool_class=pool_class, **conn_spec)
143159

144160
@classmethod
145161
def start_cluster(cls, ClusterCls, *,

asyncpg/connection.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ async def connect(dsn=None, *,
526526
timeout=60,
527527
statement_cache_size=100,
528528
command_timeout=None,
529+
connection_class=Connection,
529530
**opts):
530531
"""A coroutine to establish a connection to a PostgreSQL server.
531532
@@ -558,12 +559,16 @@ async def connect(dsn=None, *,
558559
559560
:param float timeout: connection timeout in seconds.
560561
562+
:param int statement_cache_size: the size of prepared statement LRU cache.
563+
561564
:param float command_timeout: the default timeout for operations on
562565
this connection (the default is no timeout).
563566
564-
:param int statement_cache_size: the size of prepared statement LRU cache.
567+
:param builtins.type connection_class: A class used to represent
568+
the connection.
569+
Defaults to :class:`~asyncpg.connection.Connection`.
565570
566-
:return: A :class:`~asyncpg.connection.Connection` instance.
571+
:return: A *connection_class* instance.
567572
568573
Example:
569574
@@ -577,6 +582,10 @@ async def connect(dsn=None, *,
577582
... print(types)
578583
>>> asyncio.get_event_loop().run_until_complete(run())
579584
[<Record typname='bool' typnamespace=11 ...
585+
586+
587+
.. versionadded:: 0.10.0
588+
*connection_class* argument.
580589
"""
581590
if loop is None:
582591
loop = asyncio.get_event_loop()
@@ -620,9 +629,9 @@ async def connect(dsn=None, *,
620629
tr.close()
621630
raise
622631

623-
con = Connection(pr, tr, loop, addr, opts,
624-
statement_cache_size=statement_cache_size,
625-
command_timeout=command_timeout)
632+
con = connection_class(pr, tr, loop, addr, opts,
633+
statement_cache_size=statement_cache_size,
634+
command_timeout=command_timeout)
626635
pr.set_connection(con)
627636
return con
628637

asyncpg/pool.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,14 @@ def __init__(self, *connect_args,
7070

7171
self._closed = False
7272

73+
async def _connect(self, *args, **kwargs):
74+
return await connection.connect(*args, **kwargs)
75+
7376
async def _new_connection(self):
7477
if self._working_addr is None:
75-
con = await connection.connect(*self._connect_args,
76-
loop=self._loop,
77-
**self._connect_kwargs)
78-
78+
con = await self._connect(*self._connect_args,
79+
loop=self._loop,
80+
**self._connect_kwargs)
7981
self._working_addr = con._addr
8082
self._working_opts = con._opts
8183

@@ -86,9 +88,9 @@ async def _new_connection(self):
8688
else:
8789
host, port = self._working_addr
8890

89-
con = await connection.connect(host=host, port=port,
90-
loop=self._loop,
91-
**self._working_opts)
91+
con = await self._connect(host=host, port=port,
92+
loop=self._loop,
93+
**self._working_opts)
9294

9395
if self._init is not None:
9496
await self._init(con)
@@ -177,6 +179,13 @@ async def _acquire_impl(self):
177179

178180
async def release(self, connection):
179181
"""Release a database connection back to the pool."""
182+
# Use asyncio.shield() to guarantee that task cancellation
183+
# does not prevent the connection from being returned to the
184+
# pool properly.
185+
return await asyncio.shield(self._release_impl(connection),
186+
loop=self._loop)
187+
188+
async def _release_impl(self, connection):
180189
self._check_init()
181190
if connection.is_closed():
182191
self._con_count -= 1

tests/test_pool.py

+35
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import unittest
1212

1313
from asyncpg import _testbase as tb
14+
from asyncpg import connection as pg_connection
1415
from asyncpg import cluster as pg_cluster
1516
from asyncpg import pool as pg_pool
1617

@@ -24,6 +25,19 @@
2425
POOL_NOMINAL_TIMEOUT = 0.1
2526

2627

28+
class SlowResetConnection(pg_connection.Connection):
29+
"""Connection class to simulate races with Connection.reset()."""
30+
async def reset(self):
31+
await asyncio.sleep(0.2, loop=self._loop)
32+
return await super().reset()
33+
34+
35+
class SlowResetConnectionPool(pg_pool.Pool):
36+
async def _connect(self, *args, **kwargs):
37+
return await pg_connection.connect(
38+
*args, connection_class=SlowResetConnection, **kwargs)
39+
40+
2741
class TestPool(tb.ConnectedTestCase):
2842

2943
async def test_pool_01(self):
@@ -186,6 +200,27 @@ async def worker():
186200
self.cluster.trust_local_connections()
187201
self.cluster.reload()
188202

203+
async def test_pool_handles_cancel_in_release(self):
204+
# Use SlowResetConnectionPool to simulate
205+
# the Task.cancel() and __aexit__ race.
206+
pool = await self.create_pool(database='postgres',
207+
min_size=1, max_size=1,
208+
pool_class=SlowResetConnectionPool)
209+
210+
async def worker():
211+
async with pool.acquire():
212+
pass
213+
214+
task = self.loop.create_task(worker())
215+
# Let the worker() run.
216+
await asyncio.sleep(0.1, loop=self.loop)
217+
# Cancel the worker.
218+
task.cancel()
219+
# Wait to make sure the cleanup has completed.
220+
await asyncio.sleep(0.4, loop=self.loop)
221+
# Check that the connection has been returned to the pool.
222+
self.assertEqual(pool._queue.qsize(), 1)
223+
189224

190225
@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
191226
class TestHostStandby(tb.ConnectedTestCase):

0 commit comments

Comments
 (0)