Skip to content

Commit 68b40cb

Browse files
authored
Fix set_type_codec() to accept standard SQL type names (#619)
Currently, `Connection.set_type_codec()` only accepts type names as they appear in `pg_catalog.pg_type` and would refuse to handle a standard SQL spelling of a type like `character varying`. This is an oversight, as the internal type names aren't really supposed to be treated as public Postgres API. Additionally, for historical reasons, Postgres has a single-byte `"char"` type, which is distinct from both `varchar` and SQL `char`, which may lead to massive confusion if a user sets up a custom codec on it expecting to handle the `char(n)` type instead. Issue: #617
1 parent 4a627d5 commit 68b40cb

File tree

5 files changed

+95
-22
lines changed

5 files changed

+95
-22
lines changed

asyncpg/connection.py

+30-21
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,32 @@ async def _introspect_types(self, typeoids, timeout):
428428
return await self.__execute(
429429
self._intro_query, (list(typeoids),), 0, timeout)
430430

431+
async def _introspect_type(self, typename, schema):
432+
if (
433+
schema == 'pg_catalog'
434+
and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP
435+
):
436+
typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()]
437+
rows = await self._execute(
438+
introspection.TYPE_BY_OID,
439+
[typeoid],
440+
limit=0,
441+
timeout=None,
442+
)
443+
if rows:
444+
typeinfo = rows[0]
445+
else:
446+
typeinfo = None
447+
else:
448+
typeinfo = await self.fetchrow(
449+
introspection.TYPE_BY_NAME, typename, schema)
450+
451+
if not typeinfo:
452+
raise ValueError(
453+
'unknown type: {}.{}'.format(schema, typename))
454+
455+
return typeinfo
456+
431457
def cursor(
432458
self,
433459
query,
@@ -1110,12 +1136,7 @@ async def set_type_codec(self, typename, *,
11101136
``format``.
11111137
"""
11121138
self._check_open()
1113-
1114-
typeinfo = await self.fetchrow(
1115-
introspection.TYPE_BY_NAME, typename, schema)
1116-
if not typeinfo:
1117-
raise ValueError('unknown type: {}.{}'.format(schema, typename))
1118-
1139+
typeinfo = await self._introspect_type(typename, schema)
11191140
if not introspection.is_scalar_type(typeinfo):
11201141
raise ValueError(
11211142
'cannot use custom codec on non-scalar type {}.{}'.format(
@@ -1142,15 +1163,9 @@ async def reset_type_codec(self, typename, *, schema='public'):
11421163
.. versionadded:: 0.12.0
11431164
"""
11441165

1145-
typeinfo = await self.fetchrow(
1146-
introspection.TYPE_BY_NAME, typename, schema)
1147-
if not typeinfo:
1148-
raise ValueError('unknown type: {}.{}'.format(schema, typename))
1149-
1150-
oid = typeinfo['oid']
1151-
1166+
typeinfo = await self._introspect_type(typename, schema)
11521167
self._protocol.get_settings().remove_python_codec(
1153-
oid, typename, schema)
1168+
typeinfo['oid'], typename, schema)
11541169

11551170
# Statement cache is no longer valid due to codec changes.
11561171
self._drop_local_statement_cache()
@@ -1191,13 +1206,7 @@ async def set_builtin_type_codec(self, typename, *,
11911206
core data type. Added the *format* keyword argument.
11921207
"""
11931208
self._check_open()
1194-
1195-
typeinfo = await self.fetchrow(
1196-
introspection.TYPE_BY_NAME, typename, schema)
1197-
if not typeinfo:
1198-
raise exceptions.InterfaceError(
1199-
'unknown type: {}.{}'.format(schema, typename))
1200-
1209+
typeinfo = await self._introspect_type(typename, schema)
12011210
if not introspection.is_scalar_type(typeinfo):
12021211
raise exceptions.InterfaceError(
12031212
'cannot alias non-scalar type {}.{}'.format(

asyncpg/introspection.py

+12
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,18 @@
147147
'''
148148

149149

150+
TYPE_BY_OID = '''\
151+
SELECT
152+
t.oid,
153+
t.typelem AS elemtype,
154+
t.typtype AS kind
155+
FROM
156+
pg_catalog.pg_type AS t
157+
WHERE
158+
t.oid = $1
159+
'''
160+
161+
150162
# 'b' for a base type, 'd' for a domain, 'e' for enum.
151163
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
152164

asyncpg/protocol/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
# This module is part of asyncpg and is released under
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

7+
# flake8: NOQA
78

8-
from .protocol import Protocol, Record, NO_TIMEOUT # NOQA
9+
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP

asyncpg/protocol/pgtypes.pxi

+18
Original file line numberDiff line numberDiff line change
@@ -216,5 +216,23 @@ BUILTIN_TYPE_NAME_MAP['double precision'] = \
216216
BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \
217217
BUILTIN_TYPE_NAME_MAP['timestamptz']
218218

219+
BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \
220+
BUILTIN_TYPE_NAME_MAP['timestamp']
221+
219222
BUILTIN_TYPE_NAME_MAP['time with timezone'] = \
220223
BUILTIN_TYPE_NAME_MAP['timetz']
224+
225+
BUILTIN_TYPE_NAME_MAP['time without timezone'] = \
226+
BUILTIN_TYPE_NAME_MAP['time']
227+
228+
BUILTIN_TYPE_NAME_MAP['char'] = \
229+
BUILTIN_TYPE_NAME_MAP['bpchar']
230+
231+
BUILTIN_TYPE_NAME_MAP['character'] = \
232+
BUILTIN_TYPE_NAME_MAP['bpchar']
233+
234+
BUILTIN_TYPE_NAME_MAP['character varying'] = \
235+
BUILTIN_TYPE_NAME_MAP['varchar']
236+
237+
BUILTIN_TYPE_NAME_MAP['bit varying'] = \
238+
BUILTIN_TYPE_NAME_MAP['varbit']

tests/test_codecs.py

+33
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,39 @@ async def test_custom_codec_on_domain(self):
12551255
finally:
12561256
await self.con.execute('DROP DOMAIN custom_codec_t')
12571257

1258+
async def test_custom_codec_on_stdsql_types(self):
1259+
types = [
1260+
'smallint',
1261+
'int',
1262+
'integer',
1263+
'bigint',
1264+
'decimal',
1265+
'real',
1266+
'double precision',
1267+
'timestamp with timezone',
1268+
'time with timezone',
1269+
'timestamp without timezone',
1270+
'time without timezone',
1271+
'char',
1272+
'character',
1273+
'character varying',
1274+
'bit varying',
1275+
'CHARACTER VARYING'
1276+
]
1277+
1278+
for t in types:
1279+
with self.subTest(type=t):
1280+
try:
1281+
await self.con.set_type_codec(
1282+
t,
1283+
schema='pg_catalog',
1284+
encoder=str,
1285+
decoder=str,
1286+
format='text'
1287+
)
1288+
finally:
1289+
await self.con.reset_type_codec(t, schema='pg_catalog')
1290+
12581291
async def test_custom_codec_on_enum(self):
12591292
"""Test encoding/decoding using a custom codec on an enum."""
12601293
await self.con.execute('''

0 commit comments

Comments
 (0)