Skip to content

Commit 1b1893d

Browse files
vitaly-burovoy1st1
authored andcommitted
Add API for receiving asynchronous notices
Notice message types are: WARNING, NOTICE, DEBUG, INFO, or LOG. https://www.postgresql.org/docs/current/static/protocol-error-fields.html Issue #144.
1 parent 90725f1 commit 1b1893d

File tree

6 files changed

+170
-4
lines changed

6 files changed

+170
-4
lines changed

asyncpg/connection.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta):
4141
'_stmt_cache', '_stmts_to_close', '_listeners',
4242
'_server_version', '_server_caps', '_intro_query',
4343
'_reset_query', '_proxy', '_stmt_exclusive_section',
44-
'_config', '_params', '_addr')
44+
'_config', '_params', '_addr', '_notice_callbacks')
4545

4646
def __init__(self, protocol, transport, loop,
4747
addr: (str, int) or str,
@@ -69,6 +69,7 @@ def __init__(self, protocol, transport, loop,
6969
self._stmts_to_close = set()
7070

7171
self._listeners = {}
72+
self._notice_callbacks = set()
7273

7374
settings = self._protocol.get_settings()
7475
ver_string = settings.server_version
@@ -126,6 +127,26 @@ async def remove_listener(self, channel, callback):
126127
del self._listeners[channel]
127128
await self.fetch('UNLISTEN {}'.format(channel))
128129

130+
def add_notice_callback(self, callback):
131+
"""Add a callback for Postgres notices (NOTICE, DEBUG, LOG etc.).
132+
133+
It will be called when asyncronous NoticeResponse is received
134+
from the connection. Possible message types are: WARNING, NOTICE, DEBUG,
135+
INFO, or LOG.
136+
137+
:param callable callback:
138+
A callable receiving the following arguments:
139+
**connection**: a Connection the callback is registered with;
140+
**message**: the `exceptions.PostgresNotice` message.
141+
"""
142+
if self.is_closed():
143+
raise exceptions.InterfaceError('connection is closed')
144+
self._notice_callbacks.add(callback)
145+
146+
def remove_notice_callback(self, callback):
147+
"""Remove a callback for notices."""
148+
self._notice_callbacks.discard(callback)
149+
129150
def get_server_pid(self):
130151
"""Return the PID of the Postgres server the connection is bound to."""
131152
return self._protocol.get_server_pid()
@@ -821,13 +842,15 @@ async def close(self):
821842
self._listeners = {}
822843
self._aborted = True
823844
await self._protocol.close()
845+
self._notice_callbacks = set()
824846

825847
def terminate(self):
826848
"""Terminate the connection without waiting for pending data."""
827849
self._mark_stmts_as_closed()
828850
self._listeners = {}
829851
self._aborted = True
830852
self._protocol.abort()
853+
self._notice_callbacks = set()
831854

832855
async def reset(self):
833856
self._check_open()
@@ -909,6 +932,26 @@ async def cancel():
909932

910933
self._loop.create_task(cancel())
911934

935+
def _notice(self, message):
936+
if self._proxy is None:
937+
con_ref = self
938+
else:
939+
# See the comment in the `_notify` below.
940+
con_ref = self._proxy
941+
942+
for cb in self._notice_callbacks:
943+
self._loop.call_soon(self._call_notice_cb, cb, con_ref, message)
944+
945+
def _call_notice_cb(self, cb, con_ref, message):
946+
try:
947+
cb(con_ref, message)
948+
except Exception as ex:
949+
self._loop.call_exception_handler({
950+
'message': 'Unhandled exception in asyncpg notice message '
951+
'callback {!r}'.format(cb),
952+
'exception': ex
953+
})
954+
912955
def _notify(self, pid, channel, payload):
913956
if channel not in self._listeners:
914957
return

asyncpg/exceptions/_base.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
12-
'InterfaceError')
12+
'InterfaceError', 'PostgresNotice')
1313

1414

1515
def _is_asyncpg_class(cls):
@@ -151,3 +151,10 @@ class UnknownPostgresError(FatalPostgresError):
151151

152152
class InterfaceError(Exception):
153153
"""An error caused by improper use of asyncpg API."""
154+
155+
156+
class PostgresNotice(PostgresMessage):
157+
sqlstate = '00000'
158+
159+
def __init__(self, message):
160+
self.args = [message]

asyncpg/protocol/coreproto.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -170,5 +170,6 @@ cdef class CoreProtocol:
170170

171171
cdef _on_result(self)
172172
cdef _on_notification(self, pid, channel, payload)
173+
cdef _on_notice(self, parsed)
173174
cdef _set_server_parameter(self, name, val)
174175
cdef _on_connection_lost(self, exc)

asyncpg/protocol/coreproto.pyx

+10-2
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ cdef class CoreProtocol:
5656
# NotificationResponse
5757
self._parse_msg_notification()
5858
continue
59+
elif mtype == b'N':
60+
# 'N' - NoticeResponse
61+
self._on_notice(self._parse_msg_error_response(False))
62+
continue
5963

6064
if state == PROTOCOL_AUTH:
6165
self._process__auth(mtype)
@@ -302,10 +306,9 @@ cdef class CoreProtocol:
302306
self._push_result()
303307

304308
cdef _process__simple_query(self, char mtype):
305-
if mtype in {b'D', b'I', b'N', b'T'}:
309+
if mtype in {b'D', b'I', b'T'}:
306310
# 'D' - DataRow
307311
# 'I' - EmptyQueryResponse
308-
# 'N' - NoticeResponse
309312
# 'T' - RowDescription
310313
self.buffer.consume_message()
311314

@@ -614,6 +617,8 @@ cdef class CoreProtocol:
614617
if is_error:
615618
self.result_type = RESULT_FAILED
616619
self.result = parsed
620+
else:
621+
return parsed
617622

618623
cdef _push_result(self):
619624
try:
@@ -910,6 +915,9 @@ cdef class CoreProtocol:
910915
cdef _on_result(self):
911916
pass
912917

918+
cdef _on_notice(self, parsed):
919+
pass
920+
913921
cdef _on_notification(self, pid, channel, payload):
914922
pass
915923

asyncpg/protocol/protocol.pyx

+8
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,14 @@ cdef class BaseProtocol(CoreProtocol):
731731
self.last_query = None
732732
self.return_extra = False
733733

734+
cdef _on_notice(self, parsed):
735+
# Check it here to avoid unnecessary object creation.
736+
if self.connection._notice_callbacks:
737+
message = apg_exc_base.PostgresMessage.new(
738+
parsed, query=self.last_query)
739+
740+
self.connection._notice(message)
741+
734742
cdef _on_notification(self, pid, channel, payload):
735743
self.connection._notify(pid, channel, payload)
736744

tests/test_listeners.py

+99
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99

1010
from asyncpg import _testbase as tb
11+
from asyncpg.exceptions import PostgresNotice, PostgresWarning
1112

1213

1314
class TestListeners(tb.ClusterTestCase):
@@ -74,3 +75,101 @@ def listener1(*args):
7475
self.assertEqual(
7576
await q1.get(),
7677
(con1, con2.get_server_pid(), 'ipc', 'hello'))
78+
79+
80+
class TestNotices(tb.ConnectedTestCase):
81+
async def test_notify_01(self):
82+
q1 = asyncio.Queue(loop=self.loop)
83+
84+
def notice_callb(con, message):
85+
# data in the message depend on PG's version, hide some values
86+
if message.server_source_line is not None :
87+
message.server_source_line = '***'
88+
89+
q1.put_nowait((con, type(message), message.as_dict()))
90+
91+
con = self.con
92+
con.add_notice_callback(notice_callb)
93+
await con.execute(
94+
"DO $$ BEGIN RAISE NOTICE 'catch me!'; END; $$ LANGUAGE plpgsql"
95+
)
96+
await con.execute(
97+
"DO $$ BEGIN RAISE WARNING 'catch me!'; END; $$ LANGUAGE plpgsql"
98+
)
99+
100+
expect_msg = {
101+
'context': 'PL/pgSQL function inline_code_block line 1 at RAISE',
102+
'message': 'catch me!',
103+
'server_source_filename': 'pl_exec.c',
104+
'server_source_function': 'exec_stmt_raise',
105+
'server_source_line': '***'}
106+
107+
expect_msg_notice = expect_msg.copy()
108+
expect_msg_notice.update({
109+
'severity': 'NOTICE',
110+
'severity_en': 'NOTICE',
111+
'sqlstate': '00000',
112+
})
113+
114+
expect_msg_warn = expect_msg.copy()
115+
expect_msg_warn.update({
116+
'severity': 'WARNING',
117+
'severity_en': 'WARNING',
118+
'sqlstate': '01000',
119+
})
120+
121+
if con.get_server_version() < (9, 6):
122+
del expect_msg_notice['context']
123+
del expect_msg_notice['severity_en']
124+
del expect_msg_warn['context']
125+
del expect_msg_warn['severity_en']
126+
127+
self.assertEqual(
128+
await q1.get(),
129+
(con, PostgresNotice, expect_msg_notice))
130+
131+
self.assertEqual(
132+
await q1.get(),
133+
(con, PostgresWarning, expect_msg_warn))
134+
135+
con.remove_notice_callback(notice_callb)
136+
await con.execute(
137+
"DO $$ BEGIN RAISE NOTICE '/dev/null!'; END; $$ LANGUAGE plpgsql"
138+
)
139+
140+
self.assertTrue(q1.empty())
141+
142+
143+
async def test_notify_sequence(self):
144+
q1 = asyncio.Queue(loop=self.loop)
145+
146+
cur_id = None
147+
148+
def notice_callb(con, message):
149+
q1.put_nowait((con, cur_id, message.message))
150+
151+
con = self.con
152+
await con.execute(
153+
"CREATE FUNCTION _test(i INT) RETURNS int LANGUAGE plpgsql AS $$"
154+
" BEGIN"
155+
" RAISE NOTICE '1_%', i;"
156+
" PERFORM pg_sleep(0.1);"
157+
" RAISE NOTICE '2_%', i;"
158+
" RETURN i;"
159+
" END"
160+
"$$"
161+
)
162+
con.add_notice_callback(notice_callb)
163+
for cur_id in range(10):
164+
await con.execute("SELECT _test($1)", cur_id)
165+
166+
for cur_id in range(10):
167+
self.assertEqual(
168+
q1.get_nowait(),
169+
(con, cur_id, '1_%s' % cur_id))
170+
self.assertEqual(
171+
q1.get_nowait(),
172+
(con, cur_id, '2_%s' % cur_id))
173+
174+
con.remove_notice_callback(notice_callb)
175+
self.assertTrue(q1.empty())

0 commit comments

Comments
 (0)