Skip to content

Commit 5ddd7fc

Browse files
committed
Raise an error if connection is used after being closed
1 parent a2935ae commit 5ddd7fc

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

asyncpg/connection.py

+20
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ async def add_listener(self, channel, callback):
118118
**channel**: name of the channel the notification was sent to;
119119
**payload**: the payload.
120120
"""
121+
self.__check_open()
121122
if channel not in self._listeners:
122123
await self.fetch('LISTEN {}'.format(channel))
123124
self._listeners[channel] = set()
@@ -181,6 +182,7 @@ def transaction(self, *, isolation='read_committed', readonly=False,
181182
.. _`PostgreSQL documentation`: https://www.postgresql.org/docs/\
182183
current/static/sql-set-transaction.html
183184
"""
185+
self.__check_open()
184186
return transaction.Transaction(self, isolation, readonly, deferrable)
185187

186188
async def execute(self, query: str, *args, timeout: float=None) -> str:
@@ -211,6 +213,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
211213
.. versionchanged:: 0.5.4
212214
Made it possible to pass query arguments.
213215
"""
216+
self.__check_open()
217+
214218
if not args:
215219
return await self._protocol.query(query, timeout)
216220

@@ -240,6 +244,8 @@ async def executemany(self, command: str, args,
240244
241245
.. versionadded:: 0.7.0
242246
"""
247+
self.__check_open()
248+
243249
if 'timeout' in kw:
244250
timeout = kw.pop('timeout')
245251
else:
@@ -311,6 +317,7 @@ def cursor(self, query, *args, prefetch=None, timeout=None):
311317
312318
:return: A :class:`~cursor.CursorFactory` object.
313319
"""
320+
self.__check_open()
314321
return cursor.CursorFactory(self, query, None, args,
315322
prefetch, timeout)
316323

@@ -322,6 +329,7 @@ async def prepare(self, query, *, timeout=None):
322329
323330
:return: A :class:`~prepared_stmt.PreparedStatement` instance.
324331
"""
332+
self.__check_open()
325333
stmt = await self._get_statement(query, timeout, named=True)
326334
return prepared_stmt.PreparedStatement(self, query, stmt)
327335

@@ -334,6 +342,7 @@ async def fetch(self, query, *args, timeout=None) -> list:
334342
335343
:return list: A list of :class:`Record` instances.
336344
"""
345+
self.__check_open()
337346
return await self._execute(query, args, 0, timeout)
338347

339348
async def fetchval(self, query, *args, column=0, timeout=None):
@@ -350,6 +359,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):
350359
351360
:return: The value of the specified column of the first record.
352361
"""
362+
self.__check_open()
353363
data = await self._execute(query, args, 1, timeout)
354364
if not data:
355365
return None
@@ -364,6 +374,7 @@ async def fetchrow(self, query, *args, timeout=None):
364374
365375
:return: The first row as a :class:`Record` instance.
366376
"""
377+
self.__check_open()
367378
data = await self._execute(query, args, 1, timeout)
368379
if not data:
369380
return None
@@ -384,6 +395,8 @@ async def set_type_codec(self, typename, *,
384395
data. If ``False`` (the default), the data is
385396
expected to be encoded/decoded in text.
386397
"""
398+
self.__check_open()
399+
387400
if self._type_by_name_stmt is None:
388401
self._type_by_name_stmt = await self.prepare(
389402
introspection.TYPE_BY_NAME)
@@ -412,6 +425,8 @@ async def set_builtin_type_codec(self, typename, *,
412425
(defaults to 'public')
413426
:param codec_name: The name of the builtin codec.
414427
"""
428+
self.__check_open()
429+
415430
if self._type_by_name_stmt is None:
416431
self._type_by_name_stmt = await self.prepare(
417432
introspection.TYPE_BY_NAME)
@@ -455,11 +470,16 @@ def terminate(self):
455470
self._protocol.abort()
456471

457472
async def reset(self):
473+
self.__check_open()
458474
self._listeners.clear()
459475
reset_query = self._get_reset_query()
460476
if reset_query:
461477
await self.execute(reset_query)
462478

479+
def __check_open(self):
480+
if self.is_closed():
481+
raise exceptions.InterfaceError('connection is closed')
482+
463483
def _get_unique_id(self, prefix):
464484
self._uid += 1
465485
return '__asyncpg_{}_{}__'.format(prefix, self._uid)

tests/test_connect.py

+31
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,34 @@ async def test_connection_isinstance(self):
419419
self.assertTrue(isinstance(self.con, connection.Connection))
420420
self.assertTrue(isinstance(self.con, object))
421421
self.assertFalse(isinstance(self.con, list))
422+
423+
async def test_connection_use_after_close(self):
424+
def check():
425+
return self.assertRaisesRegex(asyncpg.InterfaceError,
426+
'connection is closed')
427+
428+
await self.con.close()
429+
430+
with check():
431+
await self.con.add_listener('aaa', lambda: None)
432+
433+
with check():
434+
self.con.transaction()
435+
436+
with check():
437+
await self.con.executemany('SELECT 1', [])
438+
439+
with check():
440+
await self.con.set_type_codec('aaa', encoder=None, decoder=None)
441+
442+
with check():
443+
await self.con.set_builtin_type_codec('aaa', codec_name='aaa')
444+
445+
for meth in ('execute', 'fetch', 'fetchval', 'fetchrow',
446+
'prepare', 'cursor'):
447+
448+
with check():
449+
await getattr(self.con, meth)('SELECT 1')
450+
451+
with check():
452+
await self.con.reset()

0 commit comments

Comments
 (0)