Skip to content

Commit 41da093

Browse files
authored
Add support for coroutine functions as listener callbacks (#802)
The `Connection.add_listener()`, `Connection.add_log_listener()` and `Connection.add_termination_listener()` now allow coroutine functions as callbacks. Fixes: #567.
1 parent 1d33ff6 commit 41da093

File tree

2 files changed

+96
-45
lines changed

2 files changed

+96
-45
lines changed

asyncpg/connection.py

+57-45
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
import collections.abc
1212
import functools
1313
import itertools
14+
import inspect
1415
import os
1516
import sys
1617
import time
1718
import traceback
19+
import typing
1820
import warnings
1921
import weakref
2022

@@ -133,27 +135,32 @@ async def add_listener(self, channel, callback):
133135
:param str channel: Channel to listen on.
134136
135137
:param callable callback:
136-
A callable receiving the following arguments:
138+
A callable or a coroutine function receiving the following
139+
arguments:
137140
**connection**: a Connection the callback is registered with;
138141
**pid**: PID of the Postgres server that sent the notification;
139142
**channel**: name of the channel the notification was sent to;
140143
**payload**: the payload.
144+
145+
.. versionchanged:: 0.24.0
146+
The ``callback`` argument may be a coroutine function.
141147
"""
142148
self._check_open()
143149
if channel not in self._listeners:
144150
await self.fetch('LISTEN {}'.format(utils._quote_ident(channel)))
145151
self._listeners[channel] = set()
146-
self._listeners[channel].add(callback)
152+
self._listeners[channel].add(_Callback.from_callable(callback))
147153

148154
async def remove_listener(self, channel, callback):
149155
"""Remove a listening callback on the specified channel."""
150156
if self.is_closed():
151157
return
152158
if channel not in self._listeners:
153159
return
154-
if callback not in self._listeners[channel]:
160+
cb = _Callback.from_callable(callback)
161+
if cb not in self._listeners[channel]:
155162
return
156-
self._listeners[channel].remove(callback)
163+
self._listeners[channel].remove(cb)
157164
if not self._listeners[channel]:
158165
del self._listeners[channel]
159166
await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel)))
@@ -166,44 +173,51 @@ def add_log_listener(self, callback):
166173
DEBUG, INFO, or LOG.
167174
168175
:param callable callback:
169-
A callable receiving the following arguments:
176+
A callable or a coroutine function receiving the following
177+
arguments:
170178
**connection**: a Connection the callback is registered with;
171179
**message**: the `exceptions.PostgresLogMessage` message.
172180
173181
.. versionadded:: 0.12.0
182+
183+
.. versionchanged:: 0.24.0
184+
The ``callback`` argument may be a coroutine function.
174185
"""
175186
if self.is_closed():
176187
raise exceptions.InterfaceError('connection is closed')
177-
self._log_listeners.add(callback)
188+
self._log_listeners.add(_Callback.from_callable(callback))
178189

179190
def remove_log_listener(self, callback):
180191
"""Remove a listening callback for log messages.
181192
182193
.. versionadded:: 0.12.0
183194
"""
184-
self._log_listeners.discard(callback)
195+
self._log_listeners.discard(_Callback.from_callable(callback))
185196

186197
def add_termination_listener(self, callback):
187198
"""Add a listener that will be called when the connection is closed.
188199
189200
:param callable callback:
190-
A callable receiving one argument:
201+
A callable or a coroutine function receiving one argument:
191202
**connection**: a Connection the callback is registered with.
192203
193204
.. versionadded:: 0.21.0
205+
206+
.. versionchanged:: 0.24.0
207+
The ``callback`` argument may be a coroutine function.
194208
"""
195-
self._termination_listeners.add(callback)
209+
self._termination_listeners.add(_Callback.from_callable(callback))
196210

197211
def remove_termination_listener(self, callback):
198212
"""Remove a listening callback for connection termination.
199213
200214
:param callable callback:
201-
The callable that was passed to
215+
The callable or coroutine function that was passed to
202216
:meth:`Connection.add_termination_listener`.
203217
204218
.. versionadded:: 0.21.0
205219
"""
206-
self._termination_listeners.discard(callback)
220+
self._termination_listeners.discard(_Callback.from_callable(callback))
207221

208222
def get_server_pid(self):
209223
"""Return the PID of the Postgres server the connection is bound to."""
@@ -1449,35 +1463,21 @@ def _process_log_message(self, fields, last_query):
14491463

14501464
con_ref = self._unwrap()
14511465
for cb in self._log_listeners:
1452-
self._loop.call_soon(
1453-
self._call_log_listener, cb, con_ref, message)
1454-
1455-
def _call_log_listener(self, cb, con_ref, message):
1456-
try:
1457-
cb(con_ref, message)
1458-
except Exception as ex:
1459-
self._loop.call_exception_handler({
1460-
'message': 'Unhandled exception in asyncpg log message '
1461-
'listener callback {!r}'.format(cb),
1462-
'exception': ex
1463-
})
1466+
if cb.is_async:
1467+
self._loop.create_task(cb.cb(con_ref, message))
1468+
else:
1469+
self._loop.call_soon(cb.cb, con_ref, message)
14641470

14651471
def _call_termination_listeners(self):
14661472
if not self._termination_listeners:
14671473
return
14681474

14691475
con_ref = self._unwrap()
14701476
for cb in self._termination_listeners:
1471-
try:
1472-
cb(con_ref)
1473-
except Exception as ex:
1474-
self._loop.call_exception_handler({
1475-
'message': (
1476-
'Unhandled exception in asyncpg connection '
1477-
'termination listener callback {!r}'.format(cb)
1478-
),
1479-
'exception': ex
1480-
})
1477+
if cb.is_async:
1478+
self._loop.create_task(cb.cb(con_ref))
1479+
else:
1480+
self._loop.call_soon(cb.cb, con_ref)
14811481

14821482
self._termination_listeners.clear()
14831483

@@ -1487,18 +1487,10 @@ def _process_notification(self, pid, channel, payload):
14871487

14881488
con_ref = self._unwrap()
14891489
for cb in self._listeners[channel]:
1490-
self._loop.call_soon(
1491-
self._call_listener, cb, con_ref, pid, channel, payload)
1492-
1493-
def _call_listener(self, cb, con_ref, pid, channel, payload):
1494-
try:
1495-
cb(con_ref, pid, channel, payload)
1496-
except Exception as ex:
1497-
self._loop.call_exception_handler({
1498-
'message': 'Unhandled exception in asyncpg notification '
1499-
'listener callback {!r}'.format(cb),
1500-
'exception': ex
1501-
})
1490+
if cb.is_async:
1491+
self._loop.create_task(cb.cb(con_ref, pid, channel, payload))
1492+
else:
1493+
self._loop.call_soon(cb.cb, con_ref, pid, channel, payload)
15021494

15031495
def _unwrap(self):
15041496
if self._proxy is None:
@@ -2173,6 +2165,26 @@ def _maybe_cleanup(self):
21732165
self._on_remove(old_entry._statement)
21742166

21752167

2168+
class _Callback(typing.NamedTuple):
2169+
2170+
cb: typing.Callable[..., None]
2171+
is_async: bool
2172+
2173+
@classmethod
2174+
def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback':
2175+
if inspect.iscoroutinefunction(cb):
2176+
is_async = True
2177+
elif callable(cb):
2178+
is_async = False
2179+
else:
2180+
raise exceptions.InterfaceError(
2181+
'expected a callable or an `async def` function,'
2182+
'got {!r}'.format(cb)
2183+
)
2184+
2185+
return cls(cb, is_async)
2186+
2187+
21762188
class _Atomic:
21772189
__slots__ = ('_acquired',)
21782190

tests/test_listeners.py

+39
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,20 @@ async def test_listen_01(self):
2323

2424
q1 = asyncio.Queue()
2525
q2 = asyncio.Queue()
26+
q3 = asyncio.Queue()
2627

2728
def listener1(*args):
2829
q1.put_nowait(args)
2930

3031
def listener2(*args):
3132
q2.put_nowait(args)
3233

34+
async def async_listener3(*args):
35+
q3.put_nowait(args)
36+
3337
await con.add_listener('test', listener1)
3438
await con.add_listener('test', listener2)
39+
await con.add_listener('test', async_listener3)
3540

3641
await con.execute("NOTIFY test, 'aaaa'")
3742

@@ -41,8 +46,12 @@ def listener2(*args):
4146
self.assertEqual(
4247
await q2.get(),
4348
(con, con.get_server_pid(), 'test', 'aaaa'))
49+
self.assertEqual(
50+
await q3.get(),
51+
(con, con.get_server_pid(), 'test', 'aaaa'))
4452

4553
await con.remove_listener('test', listener2)
54+
await con.remove_listener('test', async_listener3)
4655

4756
await con.execute("NOTIFY test, 'aaaa'")
4857

@@ -117,13 +126,20 @@ class TestLogListeners(tb.ConnectedTestCase):
117126
})
118127
async def test_log_listener_01(self):
119128
q1 = asyncio.Queue()
129+
q2 = asyncio.Queue()
120130

121131
def notice_callb(con, message):
122132
# Message fields depend on PG version, hide some values.
123133
dct = message.as_dict()
124134
del dct['server_source_line']
125135
q1.put_nowait((con, type(message), dct))
126136

137+
async def async_notice_callb(con, message):
138+
# Message fields depend on PG version, hide some values.
139+
dct = message.as_dict()
140+
del dct['server_source_line']
141+
q2.put_nowait((con, type(message), dct))
142+
127143
async def raise_notice():
128144
await self.con.execute(
129145
"""DO $$
@@ -140,6 +156,7 @@ async def raise_warning():
140156

141157
con = self.con
142158
con.add_log_listener(notice_callb)
159+
con.add_log_listener(async_notice_callb)
143160

144161
expected_msg = {
145162
'context': 'PL/pgSQL function inline_code_block line 2 at RAISE',
@@ -182,7 +199,21 @@ async def raise_warning():
182199
msg,
183200
(con, exceptions.PostgresWarning, expected_msg_warn))
184201

202+
msg = await q2.get()
203+
msg[2].pop('server_source_filename', None)
204+
self.assertEqual(
205+
msg,
206+
(con, exceptions.PostgresLogMessage, expected_msg_notice))
207+
208+
msg = await q2.get()
209+
msg[2].pop('server_source_filename', None)
210+
self.assertEqual(
211+
msg,
212+
(con, exceptions.PostgresWarning, expected_msg_warn))
213+
185214
con.remove_log_listener(notice_callb)
215+
con.remove_log_listener(async_notice_callb)
216+
186217
await raise_notice()
187218
self.assertTrue(q1.empty())
188219

@@ -291,19 +322,26 @@ class TestConnectionTerminationListener(tb.ProxiedClusterTestCase):
291322
async def test_connection_termination_callback_called_on_remote(self):
292323

293324
called = False
325+
async_called = False
294326

295327
def close_cb(con):
296328
nonlocal called
297329
called = True
298330

331+
async def async_close_cb(con):
332+
nonlocal async_called
333+
async_called = True
334+
299335
con = await self.connect()
300336
con.add_termination_listener(close_cb)
337+
con.add_termination_listener(async_close_cb)
301338
self.proxy.close_all_connections()
302339
try:
303340
await con.fetchval('SELECT 1')
304341
except Exception:
305342
pass
306343
self.assertTrue(called)
344+
self.assertTrue(async_called)
307345

308346
async def test_connection_termination_callback_called_on_local(self):
309347

@@ -316,4 +354,5 @@ def close_cb(con):
316354
con = await self.connect()
317355
con.add_termination_listener(close_cb)
318356
await con.close()
357+
await asyncio.sleep(0)
319358
self.assertTrue(called)

0 commit comments

Comments
 (0)