Skip to content

Make ReadBuffer interface more robust #361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions asyncpg/protocol/buffer.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,14 @@ cdef class ReadBuffer:
cdef inline int32_t read_int32(self) except? -1
cdef inline int16_t read_int16(self) except? -1
cdef inline read_cstr(self)
cdef int32_t has_message(self) except -1
cdef inline int32_t has_message_type(self, char mtype) except -1
cdef int32_t take_message(self) except -1
cdef inline int32_t take_message_type(self, char mtype) except -1
cdef int32_t put_message(self) except -1
cdef inline const char* try_consume_message(self, ssize_t* len)
cdef Memory consume_message(self)
cdef bytearray consume_messages(self, char mtype)
cdef discard_message(self)
cdef inline _discard_message(self)
cdef finish_message(self)
cdef inline _finish_message(self)
cdef inline char get_message_type(self)
cdef inline int32_t get_message_length(self)

Expand Down
46 changes: 30 additions & 16 deletions asyncpg/protocol/buffer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ cdef class ReadBuffer:

self._ensure_first_buf()

cdef int32_t has_message(self) except -1:
cdef int32_t take_message(self) except -1:
cdef:
const char *cbuf

Expand Down Expand Up @@ -525,8 +525,24 @@ cdef class ReadBuffer:
self._current_message_ready = 1
return 1

cdef inline int32_t has_message_type(self, char mtype) except -1:
return self.has_message() and self.get_message_type() == mtype
cdef inline int32_t take_message_type(self, char mtype) except -1:
cdef const char *buf0

if self._current_message_ready:
return self._current_message_type == mtype
elif self._length >= 1:
self._ensure_first_buf()
buf0 = cpython.PyBytes_AS_STRING(self._buf0)

return buf0[self._pos0] == mtype and self.take_message()
else:
return 0

cdef int32_t put_message(self) except -1:
if not self._current_message_ready:
raise BufferError('cannot put message: no message taken')
self._current_message_ready = False
return 0

cdef inline const char* try_consume_message(self, ssize_t* len):
cdef:
Expand All @@ -541,7 +557,7 @@ cdef class ReadBuffer:
buf = self._try_read_bytes(buf_len)
if buf != NULL:
len[0] = buf_len
self._discard_message()
self._finish_message()
return buf

cdef Memory consume_message(self):
Expand All @@ -551,7 +567,7 @@ cdef class ReadBuffer:
mem = self.read(self._current_message_len_unread)
else:
mem = None
self._discard_message()
self._finish_message()
return mem

cdef bytearray consume_messages(self, char mtype):
Expand All @@ -562,7 +578,7 @@ cdef class ReadBuffer:
ssize_t total_bytes = 0
bytearray result

if not self.has_message_type(mtype):
if not self.take_message_type(mtype):
return None

# consume_messages is a volume-oriented method, so
Expand All @@ -571,26 +587,24 @@ cdef class ReadBuffer:
result = PyByteArray_FromStringAndSize(NULL, self._length)
buf = PyByteArray_AsString(result)

while self.has_message_type(mtype):
while self.take_message_type(mtype):
nbytes = self._current_message_len_unread
self._read(buf, nbytes)
buf += nbytes
total_bytes += nbytes
self._discard_message()
self._finish_message()

# Clamp the result to an actual size read.
PyByteArray_Resize(result, total_bytes)

return result

cdef discard_message(self):
if self._current_message_type == 0:
# Already discarded
cdef finish_message(self):
if self._current_message_type == 0 or not self._current_message_ready:
# The message has already been finished (e.g by consume_message()),
# or has been put back by put_message().
return

if not self._current_message_ready:
raise BufferError('no message to discard')

if self._current_message_len_unread:
if ASYNCPG_DEBUG:
mtype = chr(self._current_message_type)
Expand All @@ -602,9 +616,9 @@ cdef class ReadBuffer:
mtype,
(<Memory>discarded).as_bytes()))

self._discard_message()
self._finish_message()

cdef inline _discard_message(self):
cdef inline _finish_message(self):
self._current_message_type = 0
self._current_message_len = 0
self._current_message_ready = 0
Expand Down
29 changes: 10 additions & 19 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ cdef class CoreProtocol:
self.xact_status = PQTRANS_IDLE
self.encoding = 'utf-8'

self._skip_discard = False

# executemany support data
self._execute_iter = None
self._execute_portal_name = None
Expand All @@ -36,7 +34,7 @@ cdef class CoreProtocol:
char mtype
ProtocolState state

while self.buffer.has_message() == 1:
while self.buffer.take_message() == 1:
mtype = self.buffer.get_message_type()
state = self.state

Expand Down Expand Up @@ -150,10 +148,7 @@ cdef class CoreProtocol:
self._push_result()

finally:
if self._skip_discard:
self._skip_discard = False
else:
self.buffer.discard_message()
self.buffer.finish_message()

cdef _process__auth(self, char mtype):
if mtype == b'R':
Expand Down Expand Up @@ -319,18 +314,20 @@ cdef class CoreProtocol:

self.result = buf.consume_messages(b'd')

self._skip_discard = True

# By this point we have consumed all CopyData messages
# in the inbound buffer. If there are no messages left
# in the buffer, we need to push the accumulated data
# out to the caller in anticipation of the new CopyData
# batch. If there _are_ non-CopyData messages left,
# we must not push the result here and let the
# _process__copy_out_data subprotocol do the job.
if not buf.has_message():
if not buf.take_message():
self._on_result()
self.result = None
else:
# If there is a message in the buffer, put it back to
# be processed by the next protocol iteration.
buf.put_message()

cdef _write_copy_data_msg(self, object data):
cdef:
Expand Down Expand Up @@ -385,11 +382,9 @@ cdef class CoreProtocol:
'_parse_data_msgs: first message is not "D"')

if self._discard_data:
while True:
while buf.take_message_type(b'D'):
buf.consume_message()
if not buf.has_message() or buf.get_message_type() != b'D':
self._skip_discard = True
return
return

if ASYNCPG_DEBUG:
if type(self.result) is not list:
Expand All @@ -398,7 +393,7 @@ cdef class CoreProtocol:
format(self.result))

rows = self.result
while True:
while buf.take_message_type(b'D'):
cbuf = buf.try_consume_message(&cbuf_len)
if cbuf != NULL:
row = decoder(self, cbuf, cbuf_len)
Expand All @@ -408,10 +403,6 @@ cdef class CoreProtocol:

cpython.PyList_Append(rows, row)

if not buf.has_message() or buf.get_message_type() != b'D':
self._skip_discard = True
return

cdef _parse_msg_backend_key_data(self):
self.backend_pid = self.buffer.read_int32()
self.backend_secret = self.buffer.read_int32()
Expand Down