Skip to content

Commit a5413eb

Browse files
committed
Implement binary format codec for numeric type
Fixes: #157
1 parent f1177a0 commit a5413eb

File tree

5 files changed

+354
-6
lines changed

5 files changed

+354
-6
lines changed

asyncpg/protocol/codecs/numeric.pyx

+299-4
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,319 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8+
from libc.math cimport abs, log10
9+
from libc.stdio cimport snprintf
10+
811
import decimal
912

13+
from asyncpg.protocol cimport python
14+
15+
# defined in postgresql/src/backend/utils/adt/numeric.c
16+
DEF DEC_DIGITS = 4
17+
DEF MAX_DSCALE = 0x3FFF
18+
DEF NUMERIC_POS = 0x0000
19+
DEF NUMERIC_NEG = 0x4000
20+
DEF NUMERIC_NAN = 0xC000
1021

1122
_Dec = decimal.Decimal
1223

1324

14-
cdef numeric_encode(ConnectionSettings settings, WriteBuffer buf, obj):
25+
cdef numeric_encode_text(ConnectionSettings settings, WriteBuffer buf, obj):
1526
text_encode(settings, buf, str(obj))
1627

1728

18-
cdef numeric_decode(ConnectionSettings settings, FastReadBuffer buf):
29+
cdef numeric_decode_text(ConnectionSettings settings, FastReadBuffer buf):
1930
return _Dec(text_decode(settings, buf))
2031

2132

33+
cdef numeric_encode_binary(ConnectionSettings settings, WriteBuffer buf, obj):
34+
cdef:
35+
object dec
36+
object dt
37+
int64_t exponent
38+
int64_t i
39+
int64_t j
40+
tuple pydigits
41+
int64_t num_pydigits
42+
int16_t pgdigit
43+
int64_t num_pgdigits
44+
int16_t dscale
45+
int64_t dweight
46+
int64_t weight
47+
uint16_t sign
48+
int64_t padding_size = 0
49+
50+
if isinstance(obj, _Dec):
51+
dec = obj
52+
else:
53+
dec = _Dec(obj)
54+
55+
dt = dec.as_tuple()
56+
if dt.exponent == 'F':
57+
raise ValueError('numeric type does not support infinite values')
58+
59+
if dt.exponent == 'n' or dt.exponent == 'N':
60+
# NaN
61+
sign = NUMERIC_NAN
62+
num_pgdigits = 0
63+
weight = 0
64+
dscale = 0
65+
else:
66+
exponent = dt.exponent
67+
if exponent < 0 and -exponent > MAX_DSCALE:
68+
raise ValueError(
69+
'cannot encode Decimal value into numeric: '
70+
'exponent is too small')
71+
72+
if dt.sign:
73+
sign = NUMERIC_NEG
74+
else:
75+
sign = NUMERIC_POS
76+
77+
pydigits = dt.digits
78+
num_pydigits = len(pydigits)
79+
80+
dweight = num_pydigits + exponent - 1
81+
if dweight >= 0:
82+
weight = (dweight + DEC_DIGITS) // DEC_DIGITS - 1
83+
else:
84+
weight = -((-dweight - 1) // DEC_DIGITS + 1)
85+
86+
if weight > 2 ** 16 - 1:
87+
raise ValueError(
88+
'cannot encode Decimal value into numeric: '
89+
'exponent is too large')
90+
91+
padding_size = \
92+
(weight + 1) * DEC_DIGITS - (dweight + 1)
93+
num_pgdigits = \
94+
(num_pydigits + padding_size + DEC_DIGITS - 1) // DEC_DIGITS
95+
96+
if num_pgdigits > 2 ** 16 - 1:
97+
raise ValueError(
98+
'cannot encode Decimal value into numeric: '
99+
'number of digits is too large')
100+
101+
# Pad decimal digits to provide room for correct Postgres
102+
# digit alignment in the digit computation loop.
103+
pydigits = (0,) * DEC_DIGITS + pydigits + (0,) * DEC_DIGITS
104+
105+
if exponent < 0:
106+
if -exponent > MAX_DSCALE:
107+
raise ValueError(
108+
'cannot encode Decimal value into numeric: '
109+
'exponent is too small')
110+
dscale = <int16_t>-exponent
111+
else:
112+
dscale = 0
113+
114+
buf.write_int32(2 + 2 + 2 + 2 + 2 * <uint16_t>num_pgdigits)
115+
buf.write_int16(<int16_t>num_pgdigits)
116+
buf.write_int16(<int16_t>weight)
117+
buf.write_int16(<int16_t>sign)
118+
buf.write_int16(dscale)
119+
120+
j = DEC_DIGITS - padding_size
121+
122+
for i in range(num_pgdigits):
123+
pgdigit = (pydigits[j] * 1000 + pydigits[j + 1] * 100 +
124+
pydigits[j + 2] * 10 + pydigits[j + 3])
125+
j += DEC_DIGITS
126+
buf.write_int16(pgdigit)
127+
128+
129+
# The decoding strategy here is to form a string representation of
130+
# the numeric var, as it is faster than passing an iterable of digits.
131+
# For this reason the below code is pure overhead and is ~25% slower
132+
# than the simple text decoder above. That said, we need the binary
133+
# decoder to support binary COPY with numeric values.
134+
cdef numeric_decode_binary(ConnectionSettings settings, FastReadBuffer buf):
135+
cdef:
136+
uint16_t num_pgdigits = <uint16_t>hton.unpack_int16(buf.read(2))
137+
int16_t weight = hton.unpack_int16(buf.read(2))
138+
uint16_t sign = <uint16_t>hton.unpack_int16(buf.read(2))
139+
uint16_t dscale = <uint16_t>hton.unpack_int16(buf.read(2))
140+
int16_t pgdigit0
141+
ssize_t i
142+
int16_t pgdigit
143+
object pydigits
144+
ssize_t num_pydigits
145+
ssize_t buf_size
146+
int64_t exponent
147+
int64_t abs_exponent
148+
ssize_t exponent_chars
149+
ssize_t front_padding = 0
150+
ssize_t trailing_padding = 0
151+
ssize_t num_fract_digits
152+
ssize_t dscale_left
153+
char smallbuf[_NUMERIC_DECODER_SMALLBUF_SIZE]
154+
char *charbuf
155+
char *bufptr
156+
bint buf_allocated = False
157+
158+
if sign == NUMERIC_NAN:
159+
# Not-a-number
160+
return _Dec('NaN')
161+
162+
if num_pgdigits == 0:
163+
# Zero
164+
return _Dec()
165+
166+
pgdigit0 = hton.unpack_int16(buf.read(2))
167+
if weight >= 0:
168+
if pgdigit0 < 10:
169+
front_padding = 3
170+
elif pgdigit0 < 100:
171+
front_padding = 2
172+
elif pgdigit0 < 1000:
173+
front_padding = 1
174+
175+
# Maximum possible number of decimal digits in base 10.
176+
num_pydigits = num_pgdigits * DEC_DIGITS + dscale
177+
# Exponent.
178+
exponent = (weight + 1) * DEC_DIGITS - front_padding
179+
abs_exponent = abs(exponent)
180+
# Number of characters required to render absolute exponent value.
181+
exponent_chars = <ssize_t>log10(<double>abs_exponent) + 1
182+
183+
buf_size = (
184+
1 + # sign
185+
1 + # leading zero
186+
1 + # decimal dot
187+
num_pydigits + # digits
188+
2 + # exponent indicator (E-,E+)
189+
exponent_chars + # exponent
190+
1 # null terminator char
191+
)
192+
193+
if buf_size > _NUMERIC_DECODER_SMALLBUF_SIZE:
194+
charbuf = <char *>PyMem_Malloc(<size_t>buf_size)
195+
buf_allocated = True
196+
else:
197+
charbuf = smallbuf
198+
199+
try:
200+
bufptr = charbuf
201+
202+
if sign == NUMERIC_NEG:
203+
bufptr[0] = b'-'
204+
bufptr += 1
205+
206+
bufptr[0] = b'0'
207+
bufptr[1] = b'.'
208+
bufptr += 2
209+
210+
if weight >= 0:
211+
bufptr = _unpack_digit_stripping_lzeros(bufptr, pgdigit0)
212+
else:
213+
bufptr = _unpack_digit(bufptr, pgdigit0)
214+
215+
for i in range(1, num_pgdigits):
216+
pgdigit = hton.unpack_int16(buf.read(2))
217+
bufptr = _unpack_digit(bufptr, pgdigit)
218+
219+
if dscale:
220+
if weight >= 0:
221+
num_fract_digits = num_pgdigits - weight - 1
222+
else:
223+
num_fract_digits = num_pgdigits
224+
225+
# Check how much dscale is left to render (trailing zeros).
226+
dscale_left = dscale - num_fract_digits * DEC_DIGITS
227+
if dscale_left > 0:
228+
for i in range(dscale_left):
229+
bufptr[i] = <char>b'0'
230+
231+
# If display scale is _less_ than the number of rendered digits,
232+
# dscale_left will be negative and this will strip the excess
233+
# trailing zeros.
234+
bufptr += dscale_left
235+
236+
if exponent != 0:
237+
bufptr[0] = b'E'
238+
if exponent < 0:
239+
bufptr[1] = b'-'
240+
else:
241+
bufptr[1] = b'+'
242+
bufptr += 2
243+
snprintf(bufptr, <size_t>exponent_chars + 1, '%d',
244+
<int>abs_exponent)
245+
bufptr += exponent_chars
246+
247+
bufptr[0] = 0
248+
249+
pydigits = python.PyUnicode_FromString(charbuf)
250+
251+
return _Dec(pydigits)
252+
253+
finally:
254+
if buf_allocated:
255+
PyMem_Free(charbuf)
256+
257+
258+
cdef inline char *_unpack_digit_stripping_lzeros(char *buf, int64_t pgdigit):
259+
cdef:
260+
int64_t d
261+
bint significant
262+
263+
d = pgdigit // 1000
264+
significant = (d > 0)
265+
if significant:
266+
pgdigit -= d * 1000
267+
buf[0] = <char>(d + <int32_t>b'0')
268+
buf += 1
269+
270+
d = pgdigit // 100
271+
significant |= (d > 0)
272+
if significant:
273+
pgdigit -= d * 100
274+
buf[0] = <char>(d + <int32_t>b'0')
275+
buf += 1
276+
277+
d = pgdigit // 10
278+
significant |= (d > 0)
279+
if significant:
280+
pgdigit -= d * 10
281+
buf[0] = <char>(d + <int32_t>b'0')
282+
buf += 1
283+
284+
buf[0] = <char>(pgdigit + <int32_t>b'0')
285+
buf += 1
286+
287+
return buf
288+
289+
290+
cdef inline char *_unpack_digit(char *buf, int64_t pgdigit):
291+
cdef:
292+
int64_t d
293+
294+
d = pgdigit // 1000
295+
pgdigit -= d * 1000
296+
buf[0] = <char>(d + <int32_t>b'0')
297+
298+
d = pgdigit // 100
299+
pgdigit -= d * 100
300+
buf[1] = <char>(d + <int32_t>b'0')
301+
302+
d = pgdigit // 10
303+
pgdigit -= d * 10
304+
buf[2] = <char>(d + <int32_t>b'0')
305+
306+
buf[3] = <char>(pgdigit + <int32_t>b'0')
307+
buf += 4
308+
309+
return buf
310+
311+
22312
cdef init_numeric_codecs():
23313
register_core_codec(NUMERICOID,
24-
<encode_func>&numeric_encode,
25-
<decode_func>&numeric_decode,
314+
<encode_func>&numeric_encode_text,
315+
<decode_func>&numeric_decode_text,
26316
PG_FORMAT_TEXT)
27317

318+
register_core_codec(NUMERICOID,
319+
<encode_func>&numeric_encode_binary,
320+
<decode_func>&numeric_decode_binary,
321+
PG_FORMAT_BINARY)
322+
28323
init_numeric_codecs()

asyncpg/protocol/consts.pxi

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ DEF _MEMORY_FREELIST_SIZE = 1024
1313
DEF _MAXINT32 = 2**31 - 1
1414
DEF _COPY_BUFFER_SIZE = 524288
1515
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
16+
DEF _NUMERIC_DECODER_SMALLBUF_SIZE = 256

asyncpg/protocol/python.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ cdef extern from "Python.h":
2323
object PyMemoryView_FromMemory(char *mem, ssize_t size, int flags)
2424
object PyMemoryView_GetContiguous(object, int buffertype, char order)
2525

26+
object PyUnicode_FromString(const char *u)
2627
char* PyUnicode_AsUTF8AndSize(object unicode, ssize_t *size) except NULL
2728
char* PyByteArray_AsString(object)
2829
Py_UCS4* PyUnicode_AsUCS4Copy(object)

tests/test_codecs.py

+37
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,43 @@ async def test_interval(self):
452452
res = await self.con.fetchval("SELECT '-5 years -1 month'::interval")
453453
self.assertEqual(res, datetime.timedelta(days=-1855))
454454

455+
async def test_numeric(self):
456+
# Test that we handle dscale correctly.
457+
cases = [
458+
'0.001',
459+
'0.001000',
460+
'1',
461+
'1.00000'
462+
]
463+
464+
for case in cases:
465+
res = await self.con.fetchval(
466+
"SELECT $1::numeric", case)
467+
468+
self.assertEqual(str(res), case)
469+
470+
res = await self.con.fetchval(
471+
"SELECT $1::numeric", decimal.Decimal('NaN'))
472+
self.assertTrue(res.is_nan())
473+
474+
res = await self.con.fetchval(
475+
"SELECT $1::numeric", decimal.Decimal('sNaN'))
476+
self.assertTrue(res.is_nan())
477+
478+
with self.assertRaisesRegex(ValueError, 'numeric type does not '
479+
'support infinite values'):
480+
await self.con.fetchval(
481+
"SELECT $1::numeric", decimal.Decimal('-Inf'))
482+
483+
with self.assertRaisesRegex(ValueError, 'numeric type does not '
484+
'support infinite values'):
485+
await self.con.fetchval(
486+
"SELECT $1::numeric", decimal.Decimal('+Inf'))
487+
488+
with self.assertRaises(decimal.InvalidOperation):
489+
await self.con.fetchval(
490+
"SELECT $1::numeric", 'invalid')
491+
455492
async def test_unhandled_type_fallback(self):
456493
await self.con.execute('''
457494
CREATE EXTENSION IF NOT EXISTS isn

0 commit comments

Comments
 (0)