Skip to content

Commit ce04ac6

Browse files
committed
Untangle custom codec confusion
Asyncpg currently erroneously prefers binary I/O for the underlying type of arrays, effectively ignoring a possible custom text codec that might have been configured on a type. Fix this by removing the explicit preference for binary I/O, so that the codec selection preference is now in the following order:

- custom binary codec
- custom text codec
- builtin binary codec
- builtin text codec

Fixes: #590
Reported-by: @neumond
1 parent e585661 commit ce04ac6

File tree

6 files changed

+123
-105
lines changed

6 files changed

+123
-105
lines changed

asyncpg/connection.py

+9
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,15 @@ async def set_type_codec(self, typename, *,
11561156
.. versionchanged:: 0.13.0
11571157
The ``binary`` keyword argument was removed in favor of
11581158
``format``.
1159+
1160+
.. note::
1161+
1162+
It is recommended to use the ``'binary'`` or ``'tuple'`` *format*
1163+
whenever possible and if the underlying type supports it. Asyncpg
1164+
currently does not support text I/O for composite and range types,
1165+
and some other functionality, such as
1166+
:meth:`Connection.copy_to_table`, does not support types with text
1167+
codecs.
11591168
"""
11601169
self._check_open()
11611170
typeinfo = await self._introspect_type(typename, schema)

asyncpg/introspection.py

+10-21
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,9 @@
3737
3838
ELSE NULL
3939
END) AS basetype,
40-
t.typreceive::oid != 0 AND t.typsend::oid != 0
41-
AS has_bin_io,
4240
t.typelem AS elemtype,
4341
elem_t.typdelim AS elemdelim,
4442
range_t.rngsubtype AS range_subtype,
45-
(CASE WHEN t.typtype = 'r' THEN
46-
(SELECT
47-
range_elem_t.typreceive::oid != 0 AND
48-
range_elem_t.typsend::oid != 0
49-
FROM
50-
pg_catalog.pg_type AS range_elem_t
51-
WHERE
52-
range_elem_t.oid = range_t.rngsubtype)
53-
ELSE
54-
elem_t.typreceive::oid != 0 AND
55-
elem_t.typsend::oid != 0
56-
END) AS elem_has_bin_io,
5743
(CASE WHEN t.typtype = 'c' THEN
5844
(SELECT
5945
array_agg(ia.atttypid ORDER BY ia.attnum)
@@ -98,12 +84,12 @@
9884

9985
INTRO_LOOKUP_TYPES = '''\
10086
WITH RECURSIVE typeinfo_tree(
101-
oid, ns, name, kind, basetype, has_bin_io, elemtype, elemdelim,
102-
range_subtype, elem_has_bin_io, attrtypoids, attrnames, depth)
87+
oid, ns, name, kind, basetype, elemtype, elemdelim,
88+
range_subtype, attrtypoids, attrnames, depth)
10389
AS (
10490
SELECT
105-
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
106-
ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
91+
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
92+
ti.elemtype, ti.elemdelim, ti.range_subtype,
10793
ti.attrtypoids, ti.attrnames, 0
10894
FROM
10995
{typeinfo} AS ti
@@ -113,8 +99,8 @@
11399
UNION ALL
114100
115101
SELECT
116-
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
117-
ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
102+
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
103+
ti.elemtype, ti.elemdelim, ti.range_subtype,
118104
ti.attrtypoids, ti.attrnames, tt.depth + 1
119105
FROM
120106
{typeinfo} ti,
@@ -126,7 +112,10 @@
126112
)
127113
128114
SELECT DISTINCT
129-
*
115+
*,
116+
basetype::regtype::text AS basetype_name,
117+
elemtype::regtype::text AS elemtype_name,
118+
range_subtype::regtype::text AS range_subtype_name
130119
FROM
131120
typeinfo_tree
132121
ORDER BY

asyncpg/protocol/codecs/base.pxd

+2-1
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,5 @@ cdef class DataCodecConfig:
168168

169169
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
170170
bint ignore_custom_codec=*)
171-
cdef inline Codec get_any_local_codec(self, uint32_t oid)
171+
cdef inline Codec get_custom_codec(self, uint32_t oid,
172+
ServerDataFormat format)

asyncpg/protocol/codecs/base.pyx

+64-73
Original file line numberDiff line numberDiff line change
@@ -440,14 +440,7 @@ cdef class DataCodecConfig:
440440
for ti in types:
441441
oid = ti['oid']
442442

443-
if not ti['has_bin_io']:
444-
format = PG_FORMAT_TEXT
445-
else:
446-
format = PG_FORMAT_BINARY
447-
448-
has_text_elements = False
449-
450-
if self.get_codec(oid, format) is not None:
443+
if self.get_codec(oid, PG_FORMAT_ANY) is not None:
451444
continue
452445

453446
name = ti['name']
@@ -468,92 +461,79 @@ cdef class DataCodecConfig:
468461
name = name[1:]
469462
name = '{}[]'.format(name)
470463

471-
if ti['elem_has_bin_io']:
472-
elem_format = PG_FORMAT_BINARY
473-
else:
474-
elem_format = PG_FORMAT_TEXT
475-
476-
elem_codec = self.get_codec(array_element_oid, elem_format)
464+
elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY)
477465
if elem_codec is None:
478-
elem_format = PG_FORMAT_TEXT
479466
elem_codec = self.declare_fallback_codec(
480-
array_element_oid, name, schema)
467+
array_element_oid, ti['elemtype_name'], schema)
481468

482469
elem_delim = <Py_UCS4>ti['elemdelim'][0]
483470

484-
self._derived_type_codecs[oid, elem_format] = \
471+
self._derived_type_codecs[oid, elem_codec.format] = \
485472
Codec.new_array_codec(
486473
oid, name, schema, elem_codec, elem_delim)
487474

488475
elif ti['kind'] == b'c':
476+
# Composite type
477+
489478
if not comp_type_attrs:
490479
raise exceptions.InternalClientError(
491-
'type record missing field types for '
492-
'composite {}'.format(oid))
493-
494-
# Composite type
480+
f'type record missing field types for composite {oid}')
495481

496482
comp_elem_codecs = []
483+
has_text_elements = False
497484

498485
for typoid in comp_type_attrs:
499-
elem_codec = self.get_codec(typoid, PG_FORMAT_BINARY)
500-
if elem_codec is None:
501-
elem_codec = self.get_codec(typoid, PG_FORMAT_TEXT)
502-
has_text_elements = True
486+
elem_codec = self.get_codec(typoid, PG_FORMAT_ANY)
503487
if elem_codec is None:
504488
raise exceptions.InternalClientError(
505-
'no codec for composite attribute type {}'.format(
506-
typoid))
489+
f'no codec for composite attribute type {typoid}')
490+
if elem_codec.format is PG_FORMAT_TEXT:
491+
has_text_elements = True
507492
comp_elem_codecs.append(elem_codec)
508493

509494
element_names = collections.OrderedDict()
510495
for i, attrname in enumerate(ti['attrnames']):
511496
element_names[attrname] = i
512497

498+
# If at least one element is text-encoded, we must
499+
# encode the whole composite as text.
513500
if has_text_elements:
514-
format = PG_FORMAT_TEXT
501+
elem_format = PG_FORMAT_TEXT
502+
else:
503+
elem_format = PG_FORMAT_BINARY
515504

516-
self._derived_type_codecs[oid, format] = \
505+
self._derived_type_codecs[oid, elem_format] = \
517506
Codec.new_composite_codec(
518-
oid, name, schema, format, comp_elem_codecs,
507+
oid, name, schema, elem_format, comp_elem_codecs,
519508
comp_type_attrs, element_names)
520509

521510
elif ti['kind'] == b'd':
522511
# Domain type
523512

524513
if not base_type:
525514
raise exceptions.InternalClientError(
526-
'type record missing base type for domain {}'.format(
527-
oid))
515+
f'type record missing base type for domain {oid}')
528516

529-
elem_codec = self.get_codec(base_type, format)
517+
elem_codec = self.get_codec(base_type, PG_FORMAT_ANY)
530518
if elem_codec is None:
531-
format = PG_FORMAT_TEXT
532519
elem_codec = self.declare_fallback_codec(
533-
base_type, name, schema)
520+
base_type, ti['basetype_name'], schema)
534521

535-
self._derived_type_codecs[oid, format] = elem_codec
522+
self._derived_type_codecs[oid, elem_codec.format] = elem_codec
536523

537524
elif ti['kind'] == b'r':
538525
# Range type
539526

540527
if not range_subtype_oid:
541528
raise exceptions.InternalClientError(
542-
'type record missing base type for range {}'.format(
543-
oid))
529+
f'type record missing base type for range {oid}')
544530

545-
if ti['elem_has_bin_io']:
546-
elem_format = PG_FORMAT_BINARY
547-
else:
548-
elem_format = PG_FORMAT_TEXT
549-
550-
elem_codec = self.get_codec(range_subtype_oid, elem_format)
531+
elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
551532
if elem_codec is None:
552-
elem_format = PG_FORMAT_TEXT
553533
elem_codec = self.declare_fallback_codec(
554-
range_subtype_oid, name, schema)
534+
range_subtype_oid, ti['range_subtype_name'], schema)
555535

556-
self._derived_type_codecs[oid, elem_format] = \
536+
self._derived_type_codecs[oid, elem_codec.format] = \
557537
Codec.new_range_codec(oid, name, schema, elem_codec)
558538

559539
elif ti['kind'] == b'e':
@@ -665,10 +645,6 @@ cdef class DataCodecConfig:
665645
def declare_fallback_codec(self, uint32_t oid, str name, str schema):
666646
cdef Codec codec
667647

668-
codec = self.get_codec(oid, PG_FORMAT_TEXT)
669-
if codec is not None:
670-
return codec
671-
672648
if oid <= MAXBUILTINOID:
673649
# This is a BKI type, for which asyncpg has no
674650
# defined codec. This should only happen for newly
@@ -696,34 +672,49 @@ cdef class DataCodecConfig:
696672
bint ignore_custom_codec=False):
697673
cdef Codec codec
698674

699-
if not ignore_custom_codec:
700-
codec = self.get_any_local_codec(oid)
701-
if codec is not None:
702-
if codec.format != format:
703-
# The codec for this OID has been overridden by
704-
# set_{builtin}_type_codec with a different format.
705-
# We must respect that and not return a core codec.
706-
return None
707-
else:
708-
return codec
709-
710-
codec = get_core_codec(oid, format)
711-
if codec is not None:
675+
if format == PG_FORMAT_ANY:
676+
codec = self.get_codec(
677+
oid, PG_FORMAT_BINARY, ignore_custom_codec)
678+
if codec is None:
679+
codec = self.get_codec(
680+
oid, PG_FORMAT_TEXT, ignore_custom_codec)
712681
return codec
713682
else:
714-
try:
715-
return self._derived_type_codecs[oid, format]
716-
except KeyError:
717-
return None
683+
if not ignore_custom_codec:
684+
codec = self.get_custom_codec(oid, PG_FORMAT_ANY)
685+
if codec is not None:
686+
if codec.format != format:
687+
# The codec for this OID has been overridden by
688+
# set_{builtin}_type_codec with a different format.
689+
# We must respect that and not return a core codec.
690+
return None
691+
else:
692+
return codec
693+
694+
codec = get_core_codec(oid, format)
695+
if codec is not None:
696+
return codec
697+
else:
698+
try:
699+
return self._derived_type_codecs[oid, format]
700+
except KeyError:
701+
return None
718702

719-
cdef inline Codec get_any_local_codec(self, uint32_t oid):
703+
cdef inline Codec get_custom_codec(
704+
self,
705+
uint32_t oid,
706+
ServerDataFormat format
707+
):
720708
cdef Codec codec
721709

722-
codec = self._custom_type_codecs.get((oid, PG_FORMAT_BINARY))
723-
if codec is None:
724-
return self._custom_type_codecs.get((oid, PG_FORMAT_TEXT))
710+
if format == PG_FORMAT_ANY:
711+
codec = self.get_custom_codec(oid, PG_FORMAT_BINARY)
712+
if codec is None:
713+
codec = self.get_custom_codec(oid, PG_FORMAT_TEXT)
725714
else:
726-
return codec
715+
codec = self._custom_type_codecs.get((oid, format))
716+
717+
return codec
727718

728719

729720
cdef inline Codec get_core_codec(

asyncpg/protocol/settings.pyx

+1-10
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,7 @@ cdef class ConnectionSettings(pgproto.CodecContext):
8989
cpdef inline Codec get_data_codec(self, uint32_t oid,
9090
ServerDataFormat format=PG_FORMAT_ANY,
9191
bint ignore_custom_codec=False):
92-
if format == PG_FORMAT_ANY:
93-
codec = self._data_codecs.get_codec(
94-
oid, PG_FORMAT_BINARY, ignore_custom_codec)
95-
if codec is None:
96-
codec = self._data_codecs.get_codec(
97-
oid, PG_FORMAT_TEXT, ignore_custom_codec)
98-
return codec
99-
else:
100-
return self._data_codecs.get_codec(
101-
oid, format, ignore_custom_codec)
92+
return self._data_codecs.get_codec(oid, format, ignore_custom_codec)
10293

10394
def __getattr__(self, name):
10495
if not name.startswith('_'):

tests/test_codecs.py

+37
Original file line numberDiff line numberDiff line change
@@ -1305,6 +1305,34 @@ async def test_custom_codec_on_enum(self):
13051305
finally:
13061306
await self.con.execute('DROP TYPE custom_codec_t')
13071307

1308+
async def test_custom_codec_on_enum_array(self):
1309+
"""Test encoding/decoding using a custom codec on an enum array.
1310+
1311+
Bug: https://github.com/MagicStack/asyncpg/issues/590
1312+
"""
1313+
await self.con.execute('''
1314+
CREATE TYPE custom_codec_t AS ENUM ('foo', 'bar', 'baz')
1315+
''')
1316+
1317+
try:
1318+
await self.con.set_type_codec(
1319+
'custom_codec_t',
1320+
encoder=lambda v: str(v).lstrip('enum :'),
1321+
decoder=lambda v: 'enum: ' + str(v))
1322+
1323+
v = await self.con.fetchval(
1324+
"SELECT ARRAY['foo', 'bar']::custom_codec_t[]")
1325+
self.assertEqual(v, ['enum: foo', 'enum: bar'])
1326+
1327+
v = await self.con.fetchval(
1328+
'SELECT ARRAY[$1]::custom_codec_t[]', 'foo')
1329+
self.assertEqual(v, ['enum: foo'])
1330+
1331+
v = await self.con.fetchval("SELECT 'foo'::custom_codec_t")
1332+
self.assertEqual(v, 'enum: foo')
1333+
finally:
1334+
await self.con.execute('DROP TYPE custom_codec_t')
1335+
13081336
async def test_custom_codec_override_binary(self):
13091337
"""Test overriding core codecs."""
13101338
import json
@@ -1350,6 +1378,14 @@ def _decoder(value):
13501378
res = await conn.fetchval('SELECT $1::json', data)
13511379
self.assertEqual(data, res)
13521380

1381+
res = await conn.fetchval('SELECT $1::json[]', [data])
1382+
self.assertEqual([data], res)
1383+
1384+
await conn.execute('CREATE DOMAIN my_json AS json')
1385+
1386+
res = await conn.fetchval('SELECT $1::my_json', data)
1387+
self.assertEqual(data, res)
1388+
13531389
def _encoder(value):
13541390
return value
13551391

@@ -1365,6 +1401,7 @@ def _decoder(value):
13651401
res = await conn.fetchval('SELECT $1::uuid', data)
13661402
self.assertEqual(res, data)
13671403
finally:
1404+
await conn.execute('DROP DOMAIN IF EXISTS my_json')
13681405
await conn.close()
13691406

13701407
async def test_custom_codec_override_tuple(self):

0 commit comments

Comments
 (0)