Skip to content

Commit 71de129

Browse files
committed
Allow overriding codecs for builtin types on a connection.
Connection.set_type_codec now allows overriding codecs for builtin types. Per complaint in #73.
1 parent d207467 commit 71de129

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

asyncpg/protocol/codecs/base.pyx

+9-13
Original file line numberDiff line numberDiff line change
@@ -474,10 +474,6 @@ cdef class DataCodecConfig:
474474
encoder, decoder, binary):
475475
format = PG_FORMAT_BINARY if binary else PG_FORMAT_TEXT
476476

477-
if self.get_codec(typeoid, format) is not None:
478-
raise ValueError('cannot override codec for type {}'.format(
479-
typeoid))
480-
481477
self._local_type_codecs[typeoid, format] = \
482478
Codec.new_python_codec(typeoid, typename, typeschema, typekind,
483479
encoder, decoder, format)
@@ -519,17 +515,17 @@ cdef class DataCodecConfig:
519515
cdef inline Codec get_codec(self, uint32_t oid, CodecFormat format):
520516
cdef Codec codec
521517

522-
codec = get_core_codec(oid, format)
523-
if codec is not None:
524-
return codec
525-
526518
try:
527-
return self._type_codecs_cache[oid, format]
519+
return self._local_type_codecs[oid, format]
528520
except KeyError:
529-
try:
530-
return self._local_type_codecs[oid, format]
531-
except KeyError:
532-
return None
521+
codec = get_core_codec(oid, format)
522+
if codec is not None:
523+
return codec
524+
else:
525+
try:
526+
return self._type_codecs_cache[oid, format]
527+
except KeyError:
528+
return None
533529

534530

535531
cdef inline Codec get_core_codec(uint32_t oid, CodecFormat format):

tests/test_codecs.py

+24
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,30 @@ def hstore_encoder(obj):
900900
DROP EXTENSION hstore
901901
''')
902902

903+
async def test_custom_codec_override(self):
904+
"""Test overriding core codecs."""
905+
import json
906+
907+
conn = await self.cluster.connect(database='postgres', loop=self.loop)
908+
try:
909+
def _encoder(value):
910+
return json.dumps(value).encode('utf-8')
911+
912+
def _decoder(value):
913+
return json.loads(value.decode('utf-8'))
914+
915+
await conn.set_type_codec(
916+
'json', encoder=_encoder, decoder=_decoder,
917+
schema='pg_catalog', binary=True
918+
)
919+
920+
data = {'foo': 'bar', 'spam': 1}
921+
res = await conn.fetchval('SELECT $1::json', data)
922+
self.assertEqual(data, res)
923+
924+
finally:
925+
await conn.close()
926+
903927
async def test_composites_in_arrays(self):
904928
await self.con.execute('''
905929
CREATE TYPE t AS (a text, b int);

0 commit comments

Comments
 (0)