11
11
import collections .abc
12
12
import functools
13
13
import itertools
14
+ import inspect
14
15
import os
15
16
import sys
16
17
import time
17
18
import traceback
19
+ import typing
18
20
import warnings
19
21
import weakref
20
22
@@ -133,27 +135,32 @@ async def add_listener(self, channel, callback):
133
135
:param str channel: Channel to listen on.
134
136
135
137
:param callable callback:
136
- A callable receiving the following arguments:
138
+ A callable or a coroutine function receiving the following
139
+ arguments:
137
140
**connection**: a Connection the callback is registered with;
138
141
**pid**: PID of the Postgres server that sent the notification;
139
142
**channel**: name of the channel the notification was sent to;
140
143
**payload**: the payload.
144
+
145
+ .. versionchanged:: 0.24.0
146
+ The ``callback`` argument may be a coroutine function.
141
147
"""
142
148
self ._check_open ()
143
149
if channel not in self ._listeners :
144
150
await self .fetch ('LISTEN {}' .format (utils ._quote_ident (channel )))
145
151
self ._listeners [channel ] = set ()
146
- self ._listeners [channel ].add (callback )
152
+ self ._listeners [channel ].add (_Callback . from_callable ( callback ) )
147
153
148
154
async def remove_listener (self , channel , callback ):
149
155
"""Remove a listening callback on the specified channel."""
150
156
if self .is_closed ():
151
157
return
152
158
if channel not in self ._listeners :
153
159
return
154
- if callback not in self ._listeners [channel ]:
160
+ cb = _Callback .from_callable (callback )
161
+ if cb not in self ._listeners [channel ]:
155
162
return
156
- self ._listeners [channel ].remove (callback )
163
+ self ._listeners [channel ].remove (cb )
157
164
if not self ._listeners [channel ]:
158
165
del self ._listeners [channel ]
159
166
await self .fetch ('UNLISTEN {}' .format (utils ._quote_ident (channel )))
@@ -166,44 +173,51 @@ def add_log_listener(self, callback):
166
173
DEBUG, INFO, or LOG.
167
174
168
175
:param callable callback:
169
- A callable receiving the following arguments:
176
+ A callable or a coroutine function receiving the following
177
+ arguments:
170
178
**connection**: a Connection the callback is registered with;
171
179
**message**: the `exceptions.PostgresLogMessage` message.
172
180
173
181
.. versionadded:: 0.12.0
182
+
183
+ .. versionchanged:: 0.24.0
184
+ The ``callback`` argument may be a coroutine function.
174
185
"""
175
186
if self .is_closed ():
176
187
raise exceptions .InterfaceError ('connection is closed' )
177
- self ._log_listeners .add (callback )
188
+ self ._log_listeners .add (_Callback . from_callable ( callback ) )
178
189
179
190
def remove_log_listener (self , callback ):
180
191
"""Remove a listening callback for log messages.
181
192
182
193
.. versionadded:: 0.12.0
183
194
"""
184
- self ._log_listeners .discard (callback )
195
+ self ._log_listeners .discard (_Callback . from_callable ( callback ) )
185
196
186
197
def add_termination_listener (self , callback ):
187
198
"""Add a listener that will be called when the connection is closed.
188
199
189
200
:param callable callback:
190
- A callable receiving one argument:
201
+ A callable or a coroutine function receiving one argument:
191
202
**connection**: a Connection the callback is registered with.
192
203
193
204
.. versionadded:: 0.21.0
205
+
206
+ .. versionchanged:: 0.24.0
207
+ The ``callback`` argument may be a coroutine function.
194
208
"""
195
- self ._termination_listeners .add (callback )
209
+ self ._termination_listeners .add (_Callback . from_callable ( callback ) )
196
210
197
211
def remove_termination_listener (self , callback ):
198
212
"""Remove a listening callback for connection termination.
199
213
200
214
:param callable callback:
201
- The callable that was passed to
215
+ The callable or coroutine function that was passed to
202
216
:meth:`Connection.add_termination_listener`.
203
217
204
218
.. versionadded:: 0.21.0
205
219
"""
206
- self ._termination_listeners .discard (callback )
220
+ self ._termination_listeners .discard (_Callback . from_callable ( callback ) )
207
221
208
222
def get_server_pid (self ):
209
223
"""Return the PID of the Postgres server the connection is bound to."""
@@ -1449,35 +1463,21 @@ def _process_log_message(self, fields, last_query):
1449
1463
1450
1464
con_ref = self ._unwrap ()
1451
1465
for cb in self ._log_listeners :
1452
- self ._loop .call_soon (
1453
- self ._call_log_listener , cb , con_ref , message )
1454
-
1455
- def _call_log_listener (self , cb , con_ref , message ):
1456
- try :
1457
- cb (con_ref , message )
1458
- except Exception as ex :
1459
- self ._loop .call_exception_handler ({
1460
- 'message' : 'Unhandled exception in asyncpg log message '
1461
- 'listener callback {!r}' .format (cb ),
1462
- 'exception' : ex
1463
- })
1466
+ if cb .is_async :
1467
+ self ._loop .create_task (cb .cb (con_ref , message ))
1468
+ else :
1469
+ self ._loop .call_soon (cb .cb , con_ref , message )
1464
1470
1465
1471
def _call_termination_listeners (self ):
1466
1472
if not self ._termination_listeners :
1467
1473
return
1468
1474
1469
1475
con_ref = self ._unwrap ()
1470
1476
for cb in self ._termination_listeners :
1471
- try :
1472
- cb (con_ref )
1473
- except Exception as ex :
1474
- self ._loop .call_exception_handler ({
1475
- 'message' : (
1476
- 'Unhandled exception in asyncpg connection '
1477
- 'termination listener callback {!r}' .format (cb )
1478
- ),
1479
- 'exception' : ex
1480
- })
1477
+ if cb .is_async :
1478
+ self ._loop .create_task (cb .cb (con_ref ))
1479
+ else :
1480
+ self ._loop .call_soon (cb .cb , con_ref )
1481
1481
1482
1482
self ._termination_listeners .clear ()
1483
1483
@@ -1487,18 +1487,10 @@ def _process_notification(self, pid, channel, payload):
1487
1487
1488
1488
con_ref = self ._unwrap ()
1489
1489
for cb in self ._listeners [channel ]:
1490
- self ._loop .call_soon (
1491
- self ._call_listener , cb , con_ref , pid , channel , payload )
1492
-
1493
- def _call_listener (self , cb , con_ref , pid , channel , payload ):
1494
- try :
1495
- cb (con_ref , pid , channel , payload )
1496
- except Exception as ex :
1497
- self ._loop .call_exception_handler ({
1498
- 'message' : 'Unhandled exception in asyncpg notification '
1499
- 'listener callback {!r}' .format (cb ),
1500
- 'exception' : ex
1501
- })
1490
+ if cb .is_async :
1491
+ self ._loop .create_task (cb .cb (con_ref , pid , channel , payload ))
1492
+ else :
1493
+ self ._loop .call_soon (cb .cb , con_ref , pid , channel , payload )
1502
1494
1503
1495
def _unwrap (self ):
1504
1496
if self ._proxy is None :
@@ -2173,6 +2165,26 @@ def _maybe_cleanup(self):
2173
2165
self ._on_remove (old_entry ._statement )
2174
2166
2175
2167
2168
+ class _Callback (typing .NamedTuple ):
2169
+
2170
+ cb : typing .Callable [..., None ]
2171
+ is_async : bool
2172
+
2173
+ @classmethod
2174
+ def from_callable (cls , cb : typing .Callable [..., None ]) -> '_Callback' :
2175
+ if inspect .iscoroutinefunction (cb ):
2176
+ is_async = True
2177
+ elif callable (cb ):
2178
+ is_async = False
2179
+ else :
2180
+ raise exceptions .InterfaceError (
2181
+ 'expected a callable or an `async def` function,'
2182
+ 'got {!r}' .format (cb )
2183
+ )
2184
+
2185
+ return cls (cb , is_async )
2186
+
2187
+
2176
2188
class _Atomic :
2177
2189
__slots__ = ('_acquired' ,)
2178
2190
0 commit comments