Skip to content

Commit 2e4f0b6

Browse files
committed
Improve diagnostics of invalid executemany() input
This adds a check that elements of sequence passed to `executemany()` are proper sequences themselves and notes the offending sequence element number in the exception message. For example: await self.con.executemany( "INSERT INTO exmany (b) VALUES($1)" [(0,), ("bad",)], ) DataError: invalid input for query argument $1 in element #1 of executemany() sequence: 'bad' ('str' object cannot be interpreted as an integer) Fixes: #807
1 parent 4d39a05 commit 2e4f0b6

File tree

4 files changed

+56
-12
lines changed

4 files changed

+56
-12
lines changed

asyncpg/protocol/prepared_stmt.pxd

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cdef class PreparedStatementState:
2929
bint have_text_cols
3030
tuple rows_codecs
3131

32-
cdef _encode_bind_msg(self, args)
32+
cdef _encode_bind_msg(self, args, int seqno = ?)
3333
cpdef _init_codecs(self)
3434
cdef _ensure_rows_decoder(self)
3535
cdef _ensure_args_encoder(self)

asyncpg/protocol/prepared_stmt.pyx

+30-5
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,24 @@ cdef class PreparedStatementState:
101101
def mark_closed(self):
102102
self.closed = True
103103

104-
cdef _encode_bind_msg(self, args):
104+
cdef _encode_bind_msg(self, args, int seqno = -1):
105105
cdef:
106106
int idx
107107
WriteBuffer writer
108108
Codec codec
109109

110+
if not cpython.PySequence_Check(args):
111+
if seqno >= 0:
112+
raise exceptions.DataError(
113+
f'invalid input in executemany() argument sequence '
114+
f'element #{seqno}: expected a sequence, got {type(args)}'
115+
)
116+
else:
117+
# Non executemany() callers do not pass user input directly,
118+
# so bad input is a bug.
119+
raise exceptions.InternalClientError(
120+
f'Bind: expected a sequence, got {type(args)}')
121+
110122
if len(args) > 32767:
111123
raise exceptions.InterfaceError(
112124
'the number of query arguments cannot exceed 32767')
@@ -159,19 +171,32 @@ cdef class PreparedStatementState:
159171
except exceptions.InterfaceError as e:
160172
# This is already a descriptive error, but annotate
161173
# with argument name for clarity.
174+
pos = f'${idx + 1}'
175+
if seqno >= 0:
176+
pos = (
177+
f'{pos} in element #{seqno} of'
178+
f' executemany() sequence'
179+
)
162180
raise e.with_msg(
163-
f'query argument ${idx + 1}: {e.args[0]}') from None
181+
f'query argument {pos}: {e.args[0]}'
182+
) from None
164183
except Exception as e:
165184
# Everything else is assumed to be an encoding error
166185
# due to invalid input.
186+
pos = f'${idx + 1}'
187+
if seqno >= 0:
188+
pos = (
189+
f'{pos} in element #{seqno} of'
190+
f' executemany() sequence'
191+
)
167192
value_repr = repr(arg)
168193
if len(value_repr) > 40:
169194
value_repr = value_repr[:40] + '...'
170195

171196
raise exceptions.DataError(
172-
'invalid input for query argument'
173-
' ${n}: {v} ({msg})'.format(
174-
n=idx + 1, v=value_repr, msg=e)) from e
197+
f'invalid input for query argument'
198+
f' {pos}: {value_repr} ({e})'
199+
) from e
175200

176201
if self.have_text_cols:
177202
writer.write_int16(self.cols_num)

asyncpg/protocol/protocol.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ cdef class BaseProtocol(CoreProtocol):
217217
# Make sure the argument sequence is encoded lazily with
218218
# this generator expression to keep the memory pressure under
219219
# control.
220-
data_gen = (state._encode_bind_msg(b) for b in args)
220+
data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args))
221221
arg_bufs = iter(data_gen)
222222

223223
waiter = self._new_waiter(timeout)

tests/test_execute.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import asyncpg
1010

1111
from asyncpg import _testbase as tb
12-
from asyncpg.exceptions import UniqueViolationError
12+
from asyncpg import exceptions
1313

1414

1515
class TestExecuteScript(tb.ConnectedTestCase):
@@ -140,6 +140,25 @@ async def test_executemany_basic(self):
140140
])
141141

142142
async def test_executemany_bad_input(self):
143+
with self.assertRaisesRegex(
144+
exceptions.DataError,
145+
r"invalid input in executemany\(\) argument sequence element #1: "
146+
r"expected a sequence",
147+
):
148+
await self.con.executemany('''
149+
INSERT INTO exmany (b) VALUES($1)
150+
''', [(0,), {1: 0}])
151+
152+
with self.assertRaisesRegex(
153+
exceptions.DataError,
154+
r"invalid input for query argument \$1 in element #1 of "
155+
r"executemany\(\) sequence: 'bad'",
156+
):
157+
await self.con.executemany('''
158+
INSERT INTO exmany (b) VALUES($1)
159+
''', [(0,), ("bad",)])
160+
161+
async def test_executemany_error_in_input_gen(self):
143162
bad_data = ([1 / 0] for v in range(10))
144163

145164
with self.assertRaises(ZeroDivisionError):
@@ -155,7 +174,7 @@ async def test_executemany_bad_input(self):
155174
''', good_data)
156175

157176
async def test_executemany_server_failure(self):
158-
with self.assertRaises(UniqueViolationError):
177+
with self.assertRaises(exceptions.UniqueViolationError):
159178
await self.con.executemany('''
160179
INSERT INTO exmany VALUES($1, $2)
161180
''', [
@@ -165,7 +184,7 @@ async def test_executemany_server_failure(self):
165184
self.assertEqual(result, [])
166185

167186
async def test_executemany_server_failure_after_writes(self):
168-
with self.assertRaises(UniqueViolationError):
187+
with self.assertRaises(exceptions.UniqueViolationError):
169188
await self.con.executemany('''
170189
INSERT INTO exmany VALUES($1, $2)
171190
''', [('a' * 32768, x) for x in range(10)] + [
@@ -187,7 +206,7 @@ def gen():
187206
else:
188207
yield 'a' * 32768, pos
189208

190-
with self.assertRaises(UniqueViolationError):
209+
with self.assertRaises(exceptions.UniqueViolationError):
191210
await self.con.executemany('''
192211
INSERT INTO exmany VALUES($1, $2)
193212
''', gen())
@@ -260,7 +279,7 @@ async def test_executemany_client_failure_in_transaction(self):
260279

261280
async def test_executemany_client_server_failure_conflict(self):
262281
self.con._transport.set_write_buffer_limits(65536 * 64, 16384 * 64)
263-
with self.assertRaises(UniqueViolationError):
282+
with self.assertRaises(exceptions.UniqueViolationError):
264283
await self.con.executemany('''
265284
INSERT INTO exmany VALUES($1, 0)
266285
''', (('a' * 32768,) for y in range(4, -1, -1) if y / y))

0 commit comments

Comments
 (0)