Skip to content

Commit 4a160e9

Browse files
committed
Add 'init' parameter to create_pool(). Addresses issue #80.
1 parent 005d339 commit 4a160e9

File tree

3 files changed

+47
-7
lines changed

3 files changed

+47
-7
lines changed

asyncpg/connection.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def transaction(self, *, isolation='read_committed', readonly=False,
134134
:param deferrable: Specifies whether or not this transaction is
135135
deferrable.
136136
137-
.. _`PostgreSQL documentation`: https://www.postgresql.org/docs/current/static/sql-set-transaction.html
137+
.. _`PostgreSQL documentation`: https://www.postgresql.org/docs/\
138+
current/static/sql-set-transaction.html
138139
"""
139140
return transaction.Transaction(self, isolation, readonly, deferrable)
140141

asyncpg/pool.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ class Pool:
2626
'_connect_args', '_connect_kwargs',
2727
'_working_addr', '_working_opts',
2828
'_con_count', '_max_queries', '_connections',
29-
'_initialized', '_closed', '_setup')
29+
'_initialized', '_closed', '_setup', '_init')
3030

3131
def __init__(self, *connect_args,
3232
min_size,
3333
max_size,
3434
max_queries,
3535
setup,
36+
init,
3637
loop,
3738
**connect_kwargs):
3839

@@ -57,6 +58,7 @@ def __init__(self, *connect_args,
5758
self._max_queries = max_queries
5859

5960
self._setup = setup
61+
self._init = init
6062

6163
self._connect_args = connect_args
6264
self._connect_kwargs = connect_kwargs
@@ -88,10 +90,13 @@ async def _new_connection(self):
8890
loop=self._loop,
8991
**self._working_opts)
9092

93+
if self._init is not None:
94+
await self._init(con)
95+
9196
self._connections.add(con)
9297
return con
9398

94-
async def _init(self):
99+
async def _initialize(self):
95100
if self._initialized:
96101
return
97102
if self._closed:
@@ -219,10 +224,10 @@ def _reset(self):
219224
self._queue = asyncio.Queue(maxsize=self._maxsize, loop=self._loop)
220225

221226
def __await__(self):
222-
return self._init().__await__()
227+
return self._initialize().__await__()
223228

224229
async def __aenter__(self):
225-
await self._init()
230+
await self._initialize()
226231
return self
227232

228233
async def __aexit__(self, *exc):
@@ -261,6 +266,7 @@ def create_pool(dsn=None, *,
261266
max_size=10,
262267
max_queries=50000,
263268
setup=None,
269+
init=None,
264270
loop=None,
265271
**connect_kwargs):
266272
r"""Create a connection pool.
@@ -296,16 +302,22 @@ def create_pool(dsn=None, *,
296302
:param int max_size: Max number of connections in the pool.
297303
:param int max_queries: Number of queries after a connection is closed
298304
and replaced with a new connection.
299-
:param coroutine setup: A coroutine to initialize a connection right before
305+
:param coroutine setup: A coroutine to prepare a connection right before
300306
it is returned from :meth:`~pool.Pool.acquire`.
301307
An example use case would be to automatically
302308
set up notifications listeners for all connections
303309
of a pool.
310+
:param coroutine init: A coroutine to initialize a connection when it
311+
is created. An example use case would be to setup
312+
type codecs with
313+
:meth:`~asyncpg.connection.Connection.\
314+
set_builtin_type_codec` or :meth:`~asyncpg.\
315+
connection.Connection.set_type_codec`.
304316
:param loop: An asyncio event loop instance. If ``None``, the default
305317
event loop will be used.
306318
:return: An instance of :class:`~asyncpg.pool.Pool`.
307319
"""
308320
return Pool(dsn,
309321
min_size=min_size, max_size=max_size,
310-
max_queries=max_queries, loop=loop, setup=setup,
322+
max_queries=max_queries, loop=loop, setup=setup, init=init,
311323
**connect_kwargs)

tests/test_pool.py

+27
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,33 @@ async def setup(con):
102102

103103
self.assertIs(con, await fut)
104104

105+
async def test_pool_07(self):
106+
cons = set()
107+
108+
async def setup(con):
109+
if con not in cons:
110+
raise RuntimeError('init was not called before setup')
111+
112+
async def init(con):
113+
if con in cons:
114+
raise RuntimeError('init was called more than once')
115+
cons.add(con)
116+
117+
async def user(pool):
118+
async with pool.acquire() as con:
119+
if con not in cons:
120+
raise RuntimeError('init was not called')
121+
122+
async with self.create_pool(database='postgres',
123+
min_size=2, max_size=5,
124+
init=init,
125+
setup=setup) as pool:
126+
users = asyncio.gather(*[user(pool) for _ in range(10)],
127+
loop=self.loop)
128+
await users
129+
130+
self.assertEqual(len(cons), 5)
131+
105132
async def test_pool_auth(self):
106133
if not self.cluster.is_managed():
107134
self.skipTest('unmanaged cluster')

0 commit comments

Comments
 (0)