Skip to content

Commit 08e4c52

Browse files
committed
Make ReadBuffer interface more robust
The current ReadBuffer interface is somewhat error prone. There is no way to "peek" at the next message, so various subprotocols that have nested message processing loops have to resort to the `_skip_discard` kludge on the protocol to inform the main loop that it shouldn't skip over to the next message because the subprotocol already read too much. Fix this by adding a way to iterate over the messages without over-reading, and a way to "put" a message back into the buffer when necessary. This also renames `has_message()` to `take_message()` to make it clear that it changes the buffer state.
1 parent 2fc50e4 commit 08e4c52

File tree

3 files changed

+45
-39
lines changed

3 files changed

+45
-39
lines changed

asyncpg/protocol/buffer.pxd

+5-4
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,14 @@ cdef class ReadBuffer:
105105
cdef inline int32_t read_int32(self) except? -1
106106
cdef inline int16_t read_int16(self) except? -1
107107
cdef inline read_cstr(self)
108-
cdef int32_t has_message(self) except -1
109-
cdef inline int32_t has_message_type(self, char mtype) except -1
108+
cdef int32_t take_message(self) except -1
109+
cdef inline int32_t take_message_type(self, char mtype) except -1
110+
cdef int32_t put_message(self) except -1
110111
cdef inline const char* try_consume_message(self, ssize_t* len)
111112
cdef Memory consume_message(self)
112113
cdef bytearray consume_messages(self, char mtype)
113-
cdef discard_message(self)
114-
cdef inline _discard_message(self)
114+
cdef finish_message(self)
115+
cdef inline _finish_message(self)
115116
cdef inline char get_message_type(self)
116117
cdef inline int32_t get_message_length(self)
117118

asyncpg/protocol/buffer.pyx

+30-16
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ cdef class ReadBuffer:
489489

490490
self._ensure_first_buf()
491491

492-
cdef int32_t has_message(self) except -1:
492+
cdef int32_t take_message(self) except -1:
493493
cdef:
494494
const char *cbuf
495495

@@ -525,8 +525,24 @@ cdef class ReadBuffer:
525525
self._current_message_ready = 1
526526
return 1
527527

528-
cdef inline int32_t has_message_type(self, char mtype) except -1:
529-
return self.has_message() and self.get_message_type() == mtype
528+
cdef inline int32_t take_message_type(self, char mtype) except -1:
529+
cdef const char *buf0
530+
531+
if self._current_message_ready:
532+
return self._current_message_type == mtype
533+
elif self._length >= 1:
534+
self._ensure_first_buf()
535+
buf0 = cpython.PyBytes_AS_STRING(self._buf0)
536+
537+
return buf0[self._pos0] == mtype and self.take_message()
538+
else:
539+
return 0
540+
541+
cdef int32_t put_message(self) except -1:
542+
if not self._current_message_ready:
543+
raise BufferError('cannot put message: no message taken')
544+
self._current_message_ready = False
545+
return 0
530546

531547
cdef inline const char* try_consume_message(self, ssize_t* len):
532548
cdef:
@@ -541,7 +557,7 @@ cdef class ReadBuffer:
541557
buf = self._try_read_bytes(buf_len)
542558
if buf != NULL:
543559
len[0] = buf_len
544-
self._discard_message()
560+
self._finish_message()
545561
return buf
546562

547563
cdef Memory consume_message(self):
@@ -551,7 +567,7 @@ cdef class ReadBuffer:
551567
mem = self.read(self._current_message_len_unread)
552568
else:
553569
mem = None
554-
self._discard_message()
570+
self._finish_message()
555571
return mem
556572

557573
cdef bytearray consume_messages(self, char mtype):
@@ -562,7 +578,7 @@ cdef class ReadBuffer:
562578
ssize_t total_bytes = 0
563579
bytearray result
564580

565-
if not self.has_message_type(mtype):
581+
if not self.take_message_type(mtype):
566582
return None
567583

568584
# consume_messages is a volume-oriented method, so
@@ -571,26 +587,24 @@ cdef class ReadBuffer:
571587
result = PyByteArray_FromStringAndSize(NULL, self._length)
572588
buf = PyByteArray_AsString(result)
573589

574-
while self.has_message_type(mtype):
590+
while self.take_message_type(mtype):
575591
nbytes = self._current_message_len_unread
576592
self._read(buf, nbytes)
577593
buf += nbytes
578594
total_bytes += nbytes
579-
self._discard_message()
595+
self._finish_message()
580596

581597
# Clamp the result to an actual size read.
582598
PyByteArray_Resize(result, total_bytes)
583599

584600
return result
585601

586-
cdef discard_message(self):
587-
if self._current_message_type == 0:
588-
# Already discarded
602+
cdef finish_message(self):
603+
if self._current_message_type == 0 or not self._current_message_ready:
604+
# No message has not been finished (e.g by consume_message()),
605+
# or has been put back by put_message().
589606
return
590607

591-
if not self._current_message_ready:
592-
raise BufferError('no message to discard')
593-
594608
if self._current_message_len_unread:
595609
if ASYNCPG_DEBUG:
596610
mtype = chr(self._current_message_type)
@@ -602,9 +616,9 @@ cdef class ReadBuffer:
602616
mtype,
603617
(<Memory>discarded).as_bytes()))
604618

605-
self._discard_message()
619+
self._finish_message()
606620

607-
cdef inline _discard_message(self):
621+
cdef inline _finish_message(self):
608622
self._current_message_type = 0
609623
self._current_message_len = 0
610624
self._current_message_ready = 0

asyncpg/protocol/coreproto.pyx

+10-19
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ cdef class CoreProtocol:
2222
self.xact_status = PQTRANS_IDLE
2323
self.encoding = 'utf-8'
2424

25-
self._skip_discard = False
26-
2725
# executemany support data
2826
self._execute_iter = None
2927
self._execute_portal_name = None
@@ -36,7 +34,7 @@ cdef class CoreProtocol:
3634
char mtype
3735
ProtocolState state
3836

39-
while self.buffer.has_message() == 1:
37+
while self.buffer.take_message() == 1:
4038
mtype = self.buffer.get_message_type()
4139
state = self.state
4240

@@ -131,10 +129,7 @@ cdef class CoreProtocol:
131129
self.state = PROTOCOL_ERROR_CONSUME
132130

133131
finally:
134-
if self._skip_discard:
135-
self._skip_discard = False
136-
else:
137-
self.buffer.discard_message()
132+
self.buffer.finish_message()
138133

139134
cdef _process__auth(self, char mtype):
140135
if mtype == b'R':
@@ -402,18 +397,20 @@ cdef class CoreProtocol:
402397

403398
self.result = buf.consume_messages(b'd')
404399

405-
self._skip_discard = True
406-
407400
# By this point we have consumed all CopyData messages
408401
# in the inbound buffer. If there are no messages left
409402
# in the buffer, we need to push the accumulated data
410403
# out to the caller in anticipation of the new CopyData
411404
# batch. If there _are_ non-CopyData messages left,
412405
# we must not push the result here and let the
413406
# _process__copy_out_data subprotocol do the job.
414-
if not buf.has_message():
407+
if not buf.take_message():
415408
self._on_result()
416409
self.result = None
410+
else:
411+
# If there is a message in the buffer, put it back to
412+
# be processed by the next protocol iteration.
413+
buf.put_message()
417414

418415
cdef _write_copy_data_msg(self, object data):
419416
cdef:
@@ -468,11 +465,9 @@ cdef class CoreProtocol:
468465
'_parse_data_msgs: first message is not "D"')
469466

470467
if self._discard_data:
471-
while True:
468+
while buf.take_message_type(b'D'):
472469
buf.consume_message()
473-
if not buf.has_message() or buf.get_message_type() != b'D':
474-
self._skip_discard = True
475-
return
470+
return
476471

477472
if ASYNCPG_DEBUG:
478473
if type(self.result) is not list:
@@ -481,7 +476,7 @@ cdef class CoreProtocol:
481476
format(self.result))
482477

483478
rows = self.result
484-
while True:
479+
while buf.take_message_type(b'D'):
485480
cbuf = buf.try_consume_message(&cbuf_len)
486481
if cbuf != NULL:
487482
row = decoder(self, cbuf, cbuf_len)
@@ -491,10 +486,6 @@ cdef class CoreProtocol:
491486

492487
cpython.PyList_Append(rows, row)
493488

494-
if not buf.has_message() or buf.get_message_type() != b'D':
495-
self._skip_discard = True
496-
return
497-
498489
cdef _parse_msg_backend_key_data(self):
499490
self.backend_pid = self.buffer.read_int32()
500491
self.backend_secret = self.buffer.read_int32()

0 commit comments

Comments
 (0)