Skip to content

Commit a0ce948

Browse files
committed
Disable custom data codec for internal introspection
Fixes: #617
1 parent 68b40cb commit a0ce948

File tree

9 files changed

+78
-37
lines changed

9 files changed

+78
-37
lines changed

asyncpg/connection.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ async def _get_statement(
342342
*,
343343
named: bool=False,
344344
use_cache: bool=True,
345+
ignore_custom_codec=False,
345346
record_class=None
346347
):
347348
if record_class is None:
@@ -403,7 +404,7 @@ async def _get_statement(
403404

404405
# Now that types have been resolved, populate the codec pipeline
405406
# for the statement.
406-
statement._init_codecs()
407+
statement._init_codecs(ignore_custom_codec)
407408

408409
if need_reprepare:
409410
await self._protocol.prepare(
@@ -426,7 +427,12 @@ async def _get_statement(
426427

427428
async def _introspect_types(self, typeoids, timeout):
428429
return await self.__execute(
429-
self._intro_query, (list(typeoids),), 0, timeout)
430+
self._intro_query,
431+
(list(typeoids),),
432+
0,
433+
timeout,
434+
ignore_custom_codec=True,
435+
)
430436

431437
async def _introspect_type(self, typename, schema):
432438
if (
@@ -439,20 +445,22 @@ async def _introspect_type(self, typename, schema):
439445
[typeoid],
440446
limit=0,
441447
timeout=None,
448+
ignore_custom_codec=True,
442449
)
443-
if rows:
444-
typeinfo = rows[0]
445-
else:
446-
typeinfo = None
447450
else:
448-
typeinfo = await self.fetchrow(
449-
introspection.TYPE_BY_NAME, typename, schema)
451+
rows = await self._execute(
452+
introspection.TYPE_BY_NAME,
453+
[typename, schema],
454+
limit=1,
455+
timeout=None,
456+
ignore_custom_codec=True,
457+
)
450458

451-
if not typeinfo:
459+
if not rows:
452460
raise ValueError(
453461
'unknown type: {}.{}'.format(schema, typename))
454462

455-
return typeinfo
463+
return rows[0]
456464

457465
def cursor(
458466
self,
@@ -1589,6 +1597,7 @@ async def _execute(
15891597
timeout,
15901598
*,
15911599
return_status=False,
1600+
ignore_custom_codec=False,
15921601
record_class=None
15931602
):
15941603
with self._stmt_exclusive_section:
@@ -1599,6 +1608,7 @@ async def _execute(
15991608
timeout,
16001609
return_status=return_status,
16011610
record_class=record_class,
1611+
ignore_custom_codec=ignore_custom_codec,
16021612
)
16031613
return result
16041614

@@ -1610,6 +1620,7 @@ async def __execute(
16101620
timeout,
16111621
*,
16121622
return_status=False,
1623+
ignore_custom_codec=False,
16131624
record_class=None
16141625
):
16151626
executor = lambda stmt, timeout: self._protocol.bind_execute(
@@ -1620,6 +1631,7 @@ async def __execute(
16201631
executor,
16211632
timeout,
16221633
record_class=record_class,
1634+
ignore_custom_codec=ignore_custom_codec,
16231635
)
16241636

16251637
async def _executemany(self, query, args, timeout):
@@ -1637,20 +1649,23 @@ async def _do_execute(
16371649
timeout,
16381650
retry=True,
16391651
*,
1652+
ignore_custom_codec=False,
16401653
record_class=None
16411654
):
16421655
if timeout is None:
16431656
stmt = await self._get_statement(
16441657
query,
16451658
None,
16461659
record_class=record_class,
1660+
ignore_custom_codec=ignore_custom_codec,
16471661
)
16481662
else:
16491663
before = time.monotonic()
16501664
stmt = await self._get_statement(
16511665
query,
16521666
timeout,
16531667
record_class=record_class,
1668+
ignore_custom_codec=ignore_custom_codec,
16541669
)
16551670
after = time.monotonic()
16561671
timeout -= after - before

asyncpg/protocol/codecs/base.pxd

+2-1
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,6 @@ cdef class DataCodecConfig:
166166
dict _derived_type_codecs
167167
dict _custom_type_codecs
168168

169-
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
169+
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
170+
bint ignore_custom_codec=*)
170171
cdef inline Codec get_any_local_codec(self, uint32_t oid)

asyncpg/protocol/codecs/base.pyx

+12-10
Original file line numberDiff line numberDiff line change
@@ -692,18 +692,20 @@ cdef class DataCodecConfig:
692692

693693
return codec
694694

695-
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format):
695+
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
696+
bint ignore_custom_codec=False):
696697
cdef Codec codec
697698

698-
codec = self.get_any_local_codec(oid)
699-
if codec is not None:
700-
if codec.format != format:
701-
# The codec for this OID has been overridden by
702-
# set_{builtin}_type_codec with a different format.
703-
# We must respect that and not return a core codec.
704-
return None
705-
else:
706-
return codec
699+
if not ignore_custom_codec:
700+
codec = self.get_any_local_codec(oid)
701+
if codec is not None:
702+
if codec.format != format:
703+
# The codec for this OID has been overridden by
704+
# set_{builtin}_type_codec with a different format.
705+
# We must respect that and not return a core codec.
706+
return None
707+
else:
708+
return codec
707709

708710
codec = get_core_codec(oid, format)
709711
if codec is not None:

asyncpg/protocol/prepared_stmt.pxd

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ cdef class PreparedStatementState:
2929
tuple rows_codecs
3030

3131
cdef _encode_bind_msg(self, args)
32-
cpdef _init_codecs(self)
33-
cdef _ensure_rows_decoder(self)
34-
cdef _ensure_args_encoder(self)
32+
cpdef _init_codecs(self, bint ignore_custom_codec)
33+
cdef _ensure_rows_decoder(self, bint ignore_custom_codec)
34+
cdef _ensure_args_encoder(self, bint ignore_custom_codec)
3535
cdef _set_row_desc(self, object desc)
3636
cdef _set_args_desc(self, object desc)
3737
cdef _decode_row(self, const char* cbuf, ssize_t buf_len)

asyncpg/protocol/prepared_stmt.pyx

+10-7
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,10 @@ cdef class PreparedStatementState:
8686

8787
return missing
8888

89-
cpdef _init_codecs(self):
90-
self._ensure_args_encoder()
91-
self._ensure_rows_decoder()
89+
cpdef _init_codecs(self, bint ignore_custom_codec):
90+
91+
self._ensure_args_encoder(ignore_custom_codec)
92+
self._ensure_rows_decoder(ignore_custom_codec)
9293

9394
def attach(self):
9495
self.refs += 1
@@ -180,7 +181,7 @@ cdef class PreparedStatementState:
180181

181182
return writer
182183

183-
cdef _ensure_rows_decoder(self):
184+
cdef _ensure_rows_decoder(self, bint ignore_custom_codec):
184185
cdef:
185186
list cols_names
186187
object cols_mapping
@@ -205,7 +206,8 @@ cdef class PreparedStatementState:
205206
cols_mapping[col_name] = i
206207
cols_names.append(col_name)
207208
oid = row[3]
208-
codec = self.settings.get_data_codec(oid)
209+
codec = self.settings.get_data_codec(
210+
oid, ignore_custom_codec=ignore_custom_codec)
209211
if codec is None or not codec.has_decoder():
210212
raise exceptions.InternalClientError(
211213
'no decoder for OID {}'.format(oid))
@@ -219,7 +221,7 @@ cdef class PreparedStatementState:
219221

220222
self.rows_codecs = tuple(codecs)
221223

222-
cdef _ensure_args_encoder(self):
224+
cdef _ensure_args_encoder(self, bint ignore_custom_codec):
223225
cdef:
224226
uint32_t p_oid
225227
Codec codec
@@ -230,7 +232,8 @@ cdef class PreparedStatementState:
230232

231233
for i from 0 <= i < self.args_num:
232234
p_oid = self.parameters_desc[i]
233-
codec = self.settings.get_data_codec(p_oid)
235+
codec = self.settings.get_data_codec(
236+
p_oid, ignore_custom_codec=ignore_custom_codec)
234237
if codec is None or not codec.has_encoder():
235238
raise exceptions.InternalClientError(
236239
'no encoder for OID {}'.format(p_oid))

asyncpg/protocol/protocol.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ cdef class BaseProtocol(CoreProtocol):
411411
# No header extension
412412
wbuf.write_int32(0)
413413

414-
record_stmt._ensure_rows_decoder()
414+
record_stmt._ensure_rows_decoder(False)
415415
codecs = record_stmt.rows_codecs
416416
num_cols = len(codecs)
417417
settings = self.settings

asyncpg/protocol/settings.pxd

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ cdef class ConnectionSettings(pgproto.CodecContext):
2626
cpdef inline set_builtin_type_codec(
2727
self, typeoid, typename, typeschema, typekind, alias_to, format)
2828
cpdef inline Codec get_data_codec(
29-
self, uint32_t oid, ServerDataFormat format=*)
29+
self, uint32_t oid, ServerDataFormat format=*,
30+
bint ignore_custom_codec=*)

asyncpg/protocol/settings.pyx

+8-4
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,18 @@ cdef class ConnectionSettings(pgproto.CodecContext):
8787
typekind, alias_to, _format)
8888

8989
cpdef inline Codec get_data_codec(self, uint32_t oid,
90-
ServerDataFormat format=PG_FORMAT_ANY):
90+
ServerDataFormat format=PG_FORMAT_ANY,
91+
bint ignore_custom_codec=False):
9192
if format == PG_FORMAT_ANY:
92-
codec = self._data_codecs.get_codec(oid, PG_FORMAT_BINARY)
93+
codec = self._data_codecs.get_codec(
94+
oid, PG_FORMAT_BINARY, ignore_custom_codec)
9395
if codec is None:
94-
codec = self._data_codecs.get_codec(oid, PG_FORMAT_TEXT)
96+
codec = self._data_codecs.get_codec(
97+
oid, PG_FORMAT_TEXT, ignore_custom_codec)
9598
return codec
9699
else:
97-
return self._data_codecs.get_codec(oid, format)
100+
return self._data_codecs.get_codec(
101+
oid, format, ignore_custom_codec)
98102

99103
def __getattr__(self, name):
100104
if not name.startswith('_'):

tests/test_introspection.py

+15
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,20 @@ def tearDownClass(cls):
4343

4444
super().tearDownClass()
4545

46+
def setUp(self):
47+
super().setUp()
48+
self.loop.run_until_complete(self._add_custom_codec(self.con))
49+
50+
async def _add_custom_codec(self, conn):
51+
# mess up with the codec - builtin introspection shouldn't be affected
52+
await conn.set_type_codec(
53+
"oid",
54+
schema="pg_catalog",
55+
encoder=lambda value: None,
56+
decoder=lambda value: None,
57+
format="text",
58+
)
59+
4660
@tb.with_connection_options(database='asyncpg_intro_test')
4761
async def test_introspection_on_large_db(self):
4862
await self.con.execute(
@@ -142,6 +156,7 @@ async def test_introspection_retries_after_cache_bust(self):
142156
# query would cause introspection to retry.
143157
slow_intro_conn = await self.connect(
144158
connection_class=SlowIntrospectionConnection)
159+
await self._add_custom_codec(slow_intro_conn)
145160
try:
146161
await self.con.execute('''
147162
CREATE DOMAIN intro_1_t AS int;

0 commit comments

Comments
 (0)