Skip to content

Commit f3a7b0e

Browse files
committed
Fix set_type_codec() to accept standard SQL type names
Currently, `Connection.set_type_codec()` only accepts type names as they appear in `pg_catalog.pg_type` and would refuse to handle a standard SQL spelling of a type like `character varying`. This is an oversight, as the internal type names aren't really supposed to be treated as public Postgres API. Additionally, for historical reasons, Postgres has a single-byte `"char"` type, which is distinct from both `varchar` and SQL `char`, which may lead to massive confusion if a user sets up a custom codec on it expecting to handle the `char(n)` type instead. Issue: #617.
1 parent 2bac166 commit f3a7b0e

File tree

5 files changed

+93
-22
lines changed

5 files changed

+93
-22
lines changed

asyncpg/connection.py

+28-21
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,30 @@ async def _introspect_types(self, typeoids, timeout):
426426
return await self.__execute(
427427
self._intro_query, (list(typeoids),), 0, timeout)
428428

429+
async def _introspect_type(self, typename, schema):
430+
if (
431+
schema == 'pg_catalog'
432+
and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP
433+
):
434+
typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()]
435+
rows = await self._execute(
436+
introspection.TYPE_BY_OID,
437+
[typeoid],
438+
limit=0,
439+
timeout=None,
440+
)
441+
if rows:
442+
typeinfo = rows[0]
443+
else:
444+
typeinfo = await self.fetchrow(
445+
introspection.TYPE_BY_NAME, typename, schema)
446+
447+
if not typeinfo:
448+
raise ValueError(
449+
'unknown type: {}.{}'.format(schema, typename))
450+
451+
return typeinfo
452+
429453
def cursor(
430454
self,
431455
query,
@@ -1108,12 +1132,7 @@ async def set_type_codec(self, typename, *,
11081132
``format``.
11091133
"""
11101134
self._check_open()
1111-
1112-
typeinfo = await self.fetchrow(
1113-
introspection.TYPE_BY_NAME, typename, schema)
1114-
if not typeinfo:
1115-
raise ValueError('unknown type: {}.{}'.format(schema, typename))
1116-
1135+
typeinfo = await self._introspect_type(typename, schema)
11171136
if not introspection.is_scalar_type(typeinfo):
11181137
raise ValueError(
11191138
'cannot use custom codec on non-scalar type {}.{}'.format(
@@ -1140,15 +1159,9 @@ async def reset_type_codec(self, typename, *, schema='public'):
11401159
.. versionadded:: 0.12.0
11411160
"""
11421161

1143-
typeinfo = await self.fetchrow(
1144-
introspection.TYPE_BY_NAME, typename, schema)
1145-
if not typeinfo:
1146-
raise ValueError('unknown type: {}.{}'.format(schema, typename))
1147-
1148-
oid = typeinfo['oid']
1149-
1162+
typeinfo = await self._introspect_type(typename, schema)
11501163
self._protocol.get_settings().remove_python_codec(
1151-
oid, typename, schema)
1164+
typeinfo['oid'], typename, schema)
11521165

11531166
# Statement cache is no longer valid due to codec changes.
11541167
self._drop_local_statement_cache()
@@ -1189,13 +1202,7 @@ async def set_builtin_type_codec(self, typename, *,
11891202
core data type. Added the *format* keyword argument.
11901203
"""
11911204
self._check_open()
1192-
1193-
typeinfo = await self.fetchrow(
1194-
introspection.TYPE_BY_NAME, typename, schema)
1195-
if not typeinfo:
1196-
raise exceptions.InterfaceError(
1197-
'unknown type: {}.{}'.format(schema, typename))
1198-
1205+
typeinfo = await self._introspect_type(typename, schema)
11991206
if not introspection.is_scalar_type(typeinfo):
12001207
raise exceptions.InterfaceError(
12011208
'cannot alias non-scalar type {}.{}'.format(

asyncpg/introspection.py

+12
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,18 @@
147147
'''
148148

149149

150+
TYPE_BY_OID = '''\
151+
SELECT
152+
t.oid,
153+
t.typelem AS elemtype,
154+
t.typtype AS kind
155+
FROM
156+
pg_catalog.pg_type AS t
157+
WHERE
158+
t.oid = $1
159+
'''
160+
161+
150162
# 'b' for a base type, 'd' for a domain, 'e' for enum.
151163
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
152164

asyncpg/protocol/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
# This module is part of asyncpg and is released under
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

7+
# flake8: NOQA
78

8-
from .protocol import Protocol, Record, NO_TIMEOUT # NOQA
9+
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP

asyncpg/protocol/pgtypes.pxi

+18
Original file line numberDiff line numberDiff line change
@@ -216,5 +216,23 @@ BUILTIN_TYPE_NAME_MAP['double precision'] = \
216216
BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \
217217
BUILTIN_TYPE_NAME_MAP['timestamptz']
218218

219+
BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \
220+
BUILTIN_TYPE_NAME_MAP['timestamp']
221+
219222
BUILTIN_TYPE_NAME_MAP['time with timezone'] = \
220223
BUILTIN_TYPE_NAME_MAP['timetz']
224+
225+
BUILTIN_TYPE_NAME_MAP['time without timezone'] = \
226+
BUILTIN_TYPE_NAME_MAP['time']
227+
228+
BUILTIN_TYPE_NAME_MAP['char'] = \
229+
BUILTIN_TYPE_NAME_MAP['bpchar']
230+
231+
BUILTIN_TYPE_NAME_MAP['character'] = \
232+
BUILTIN_TYPE_NAME_MAP['bpchar']
233+
234+
BUILTIN_TYPE_NAME_MAP['character varying'] = \
235+
BUILTIN_TYPE_NAME_MAP['varchar']
236+
237+
BUILTIN_TYPE_NAME_MAP['bit varying'] = \
238+
BUILTIN_TYPE_NAME_MAP['varbit']

tests/test_codecs.py

+33
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,39 @@ async def test_custom_codec_on_domain(self):
12551255
finally:
12561256
await self.con.execute('DROP DOMAIN custom_codec_t')
12571257

1258+
async def test_custom_codec_on_stdsql_types(self):
1259+
types = [
1260+
'smallint',
1261+
'int',
1262+
'integer',
1263+
'bigint',
1264+
'decimal',
1265+
'real',
1266+
'double precision',
1267+
'timestamp with timezone',
1268+
'time with timezone',
1269+
'timestamp without timezone',
1270+
'time without timezone',
1271+
'char',
1272+
'character',
1273+
'character varying',
1274+
'bit varying',
1275+
'CHARACTER VARYING'
1276+
]
1277+
1278+
for t in types:
1279+
with self.subTest(type=t):
1280+
try:
1281+
await self.con.set_type_codec(
1282+
t,
1283+
schema='pg_catalog',
1284+
encoder=str,
1285+
decoder=str,
1286+
format='text',
1287+
)
1288+
finally:
1289+
await self.con.reset_type_codec(t, schema='pg_catalog')
1290+
12581291
async def test_custom_codec_on_enum(self):
12591292
"""Test encoding/decoding using a custom codec on an enum."""
12601293
await self.con.execute('''

0 commit comments

Comments
 (0)