
Commit acd104d

BrokerConnection.receive_bytes(data) -> response events
1 parent 9cd39bf commit acd104d

File tree

3 files changed: 91 additions, 93 deletions


kafka/client_async.py

Lines changed: 3 additions & 13 deletions
@@ -603,25 +603,14 @@ def _poll(self, timeout, sleep=True):
                 continue
 
             self._idle_expiry_manager.update(conn.node_id)
-
-            # Accumulate as many responses as the connection has pending
-            while conn.in_flight_requests:
-                response = conn.recv()  # Note: conn.recv runs callbacks / errbacks
-
-                # Incomplete responses are buffered internally
-                # while conn.in_flight_requests retains the request
-                if not response:
-                    break
-                responses.append(response)
+            responses.extend(conn.recv())  # Note: conn.recv runs callbacks / errbacks
 
         # Check for additional pending SSL bytes
         if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
             # TODO: optimize
             for conn in self._conns.values():
                 if conn not in processed and conn.connected() and conn._sock.pending():
-                    response = conn.recv()
-                    if response:
-                        responses.append(response)
+                    responses.extend(conn.recv())
 
         for conn in six.itervalues(self._conns):
             if conn.requests_timed_out():
@@ -633,6 +622,7 @@ def _poll(self, timeout, sleep=True):
 
         if self._sensors:
             self._sensors.io_time.record((time.time() - end_select) * 1000000000)
+
         self._maybe_close_oldest_connection()
         return responses

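With this change the poll loop's contract is simpler: conn.recv() always returns a list of decoded responses (possibly empty), so _poll just extends its result list instead of looping while in_flight_requests and checking each response for None. A minimal sketch of that contract with a stand-in connection class (FakeConn is illustrative only, not part of kafka-python):

# FakeConn mimics only the new recv() contract: it always returns a list of
# ready responses, which is empty when nothing complete has arrived yet.
class FakeConn(object):
    def __init__(self, pending):
        self._pending = list(pending)

    def recv(self):
        ready, self._pending = self._pending, []
        return ready

responses = []
for conn in [FakeConn(['metadata-response', 'fetch-response']), FakeConn([])]:
    responses.extend(conn.recv())   # no per-response None checks needed
assert responses == ['metadata-response', 'fetch-response']
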
kafka/conn.py

Lines changed: 84 additions & 77 deletions
@@ -4,7 +4,6 @@
 import copy
 import errno
 import logging
-import io
 from random import shuffle, uniform
 import socket
 import time
@@ -18,6 +17,7 @@
 from kafka.protocol.api import RequestHeader
 from kafka.protocol.admin import SaslHandShakeRequest
 from kafka.protocol.commit import GroupCoordinatorResponse
+from kafka.protocol.frame import KafkaBytes
 from kafka.protocol.metadata import MetadataRequest
 from kafka.protocol.types import Int32
 from kafka.version import __version__
@@ -214,9 +214,9 @@ def __init__(self, host, port, afi, **configs):
         if self.config['ssl_context'] is not None:
             self._ssl_context = self.config['ssl_context']
         self._sasl_auth_future = None
-        self._rbuffer = io.BytesIO()
+        self._header = KafkaBytes(4)
+        self._rbuffer = None
         self._receiving = False
-        self._next_payload_bytes = 0
         self.last_attempt = 0
         self._processing = False
         self._correlation_id = 0
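
The connection now reads into kafka.protocol.frame.KafkaBytes buffers instead of a single io.BytesIO: a fixed 4-byte header buffer allocated once per connection, plus a payload buffer allocated per response once its size is known. frame.py itself is not part of this diff, so the class below is only a hypothetical stand-in inferred from how the buffers are used here (fixed size at construction, tell/seek/read/write, len()):

# Hypothetical stand-in for KafkaBytes, inferred from its usage in this commit.
class FakeKafkaBytes(bytearray):
    def __init__(self, size):
        super(FakeKafkaBytes, self).__init__(size)   # preallocate `size` zero bytes
        self._idx = 0                                # current read/write position

    def tell(self):
        return self._idx

    def seek(self, idx):
        self._idx = idx

    def read(self, nbytes):
        start, self._idx = self._idx, self._idx + nbytes
        return bytes(self[start:self._idx])

    def write(self, data):
        start, self._idx = self._idx, self._idx + len(data)
        self[start:self._idx] = data

header = FakeKafkaBytes(4)
header.write(b'\x00\x00')     # a size prefix may arrive split across reads
header.write(b'\x00\x07')
header.seek(0)
assert header.read(4) == b'\x00\x00\x00\x07'
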
@@ -552,17 +552,19 @@ def close(self, error=None):
         self.state = ConnectionStates.DISCONNECTED
         self.last_attempt = time.time()
         self._sasl_auth_future = None
-        self._receiving = False
-        self._next_payload_bytes = 0
-        self._rbuffer.seek(0)
-        self._rbuffer.truncate()
+        self._reset_buffer()
         if error is None:
             error = Errors.Cancelled(str(self))
         while self.in_flight_requests:
             ifr = self.in_flight_requests.popleft()
             ifr.future.failure(error)
         self.config['state_change_callback'](self)
 
+    def _reset_buffer(self):
+        self._receiving = False
+        self._header.seek(0)
+        self._rbuffer = None
+
     def send(self, request):
         """send request, return Future()
 
@@ -636,11 +638,11 @@ def recv(self):
         # fail all the pending request futures
         if self.in_flight_requests:
             self.close(Errors.ConnectionError('Socket not connected during recv with in-flight-requests'))
-            return None
+            return ()
 
         elif not self.in_flight_requests:
             log.warning('%s: No in-flight-requests to recv', self)
-            return None
+            return ()
 
         response = self._recv()
         if not response and self.requests_timed_out():
@@ -649,103 +651,108 @@
             self.close(error=Errors.RequestTimedOutError(
                 'Request timed out after %s ms' %
                 self.config['request_timeout_ms']))
-            return None
+            return ()
         return response
 
     def _recv(self):
-        # Not receiving is the state of reading the payload header
-        if not self._receiving:
+        responses = []
+        SOCK_CHUNK_BYTES = 4096
+        while True:
             try:
-                bytes_to_read = 4 - self._rbuffer.tell()
-                data = self._sock.recv(bytes_to_read)
+                data = self._sock.recv(SOCK_CHUNK_BYTES)
                 # We expect socket.recv to raise an exception if there is not
                 # enough data to read the full bytes_to_read
                 # but if the socket is disconnected, we will get empty data
                 # without an exception raised
                 if not data:
                     log.error('%s: socket disconnected', self)
                     self.close(error=Errors.ConnectionError('socket disconnected'))
-                    return None
-                self._rbuffer.write(data)
+                    break
+                else:
+                    responses.extend(self.receive_bytes(data))
+                    if len(data) < SOCK_CHUNK_BYTES:
+                        break
             except SSLWantReadError:
-                return None
+                break
             except ConnectionError as e:
                 if six.PY2 and e.errno == errno.EWOULDBLOCK:
-                    return None
-                log.exception('%s: Error receiving 4-byte payload header -'
+                    break
+                log.exception('%s: Error receiving network data'
                               ' closing socket', self)
                 self.close(error=Errors.ConnectionError(e))
-                return None
-            except BlockingIOError:
-                if six.PY3:
-                    return None
-                raise
-
-            if self._rbuffer.tell() == 4:
-                self._rbuffer.seek(0)
-                self._next_payload_bytes = Int32.decode(self._rbuffer)
-                # reset buffer and switch state to receiving payload bytes
-                self._rbuffer.seek(0)
-                self._rbuffer.truncate()
-                self._receiving = True
-            elif self._rbuffer.tell() > 4:
-                raise Errors.KafkaError('this should not happen - are you threading?')
-
-        if self._receiving:
-            staged_bytes = self._rbuffer.tell()
-            try:
-                bytes_to_read = self._next_payload_bytes - staged_bytes
-                data = self._sock.recv(bytes_to_read)
-                # We expect socket.recv to raise an exception if there is not
-                # enough data to read the full bytes_to_read
-                # but if the socket is disconnected, we will get empty data
-                # without an exception raised
-                if bytes_to_read and not data:
-                    log.error('%s: socket disconnected', self)
-                    self.close(error=Errors.ConnectionError('socket disconnected'))
-                    return None
-                self._rbuffer.write(data)
-            except SSLWantReadError:
-                return None
-            except ConnectionError as e:
-                # Extremely small chance that we have exactly 4 bytes for a
-                # header, but nothing to read in the body yet
-                if six.PY2 and e.errno == errno.EWOULDBLOCK:
-                    return None
-                log.exception('%s: Error in recv', self)
-                self.close(error=Errors.ConnectionError(e))
-                return None
+                break
             except BlockingIOError:
                 if six.PY3:
-                    return None
+                    break
                 raise
+        return responses
 
-        staged_bytes = self._rbuffer.tell()
-        if staged_bytes > self._next_payload_bytes:
-            self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
-
-        if staged_bytes != self._next_payload_bytes:
-            return None
+    def receive_bytes(self, data):
+        i = 0
+        n = len(data)
+        responses = []
+        if self._sensors:
+            self._sensors.bytes_received.record(n)
+        while i < n:
+
+            # Not receiving is the state of reading the payload header
+            if not self._receiving:
+                bytes_to_read = min(4 - self._header.tell(), n - i)
+                self._header.write(data[i:i+bytes_to_read])
+                i += bytes_to_read
+
+                if self._header.tell() == 4:
+                    self._header.seek(0)
+                    nbytes = Int32.decode(self._header)
+                    # reset buffer and switch state to receiving payload bytes
+                    self._rbuffer = KafkaBytes(nbytes)
+                    self._receiving = True
+                elif self._header.tell() > 4:
+                    raise Errors.KafkaError('this should not happen - are you threading?')
+
+            if self._receiving:
+                total_bytes = len(self._rbuffer)
+                staged_bytes = self._rbuffer.tell()
+                bytes_to_read = min(total_bytes - staged_bytes, n - i)
+                self._rbuffer.write(data[i:i+bytes_to_read])
+                i += bytes_to_read
+
+                staged_bytes = self._rbuffer.tell()
+                if staged_bytes > total_bytes:
+                    self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
+
+                if staged_bytes != total_bytes:
+                    break
 
-        self._receiving = False
-        self._next_payload_bytes = 0
-        if self._sensors:
-            self._sensors.bytes_received.record(4 + self._rbuffer.tell())
-        self._rbuffer.seek(0)
-        response = self._process_response(self._rbuffer)
-        self._rbuffer.seek(0)
-        self._rbuffer.truncate()
-        return response
+                self._receiving = False
+                self._rbuffer.seek(0)
+                resp = self._process_response(self._rbuffer)
+                if resp is not None:
+                    responses.append(resp)
+                self._reset_buffer()
+        return responses
 
     def _process_response(self, read_buffer):
         assert not self._processing, 'Recursion not supported'
         self._processing = True
-        ifr = self.in_flight_requests.popleft()
+        recv_correlation_id = Int32.decode(read_buffer)
+
+        if not self.in_flight_requests:
+            error = Errors.CorrelationIdError(
+                '%s: No in-flight-request found for server response'
+                ' with correlation ID %d'
+                % (self, recv_correlation_id))
+            self.close(error)
+            self._processing = False
+            return None
+        else:
+            ifr = self.in_flight_requests.popleft()
+
         if self._sensors:
             self._sensors.request_time.record((time.time() - ifr.timestamp) * 1000)
 
         # verify send/recv correlation ids match
-        recv_correlation_id = Int32.decode(read_buffer)
 
         # 0.8.2 quirk
         if (self.config['api_version'] == (0, 8, 2) and

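The core of the commit is receive_bytes: each Kafka response on the wire is a 4-byte big-endian size prefix followed by that many payload bytes, and a single chunk of socket data may now contain any mix of complete and partial frames. The snippet below is a self-contained sketch of the same framing state machine, written independently of kafka-python, to show how frames are reassembled across arbitrary chunk boundaries:

# Stand-alone illustration of size-prefixed framing; not the library's code.
import struct

class FrameAssembler(object):
    def __init__(self):
        self._header = b''       # partial 4-byte size prefix
        self._payload = b''      # partial payload for the current frame
        self._expected = None    # payload size once the prefix is complete

    def feed(self, data):
        """Consume a chunk of bytes, return any frames completed by it."""
        frames = []
        i, n = 0, len(data)
        while i < n:
            if self._expected is None:
                take = min(4 - len(self._header), n - i)
                self._header += data[i:i + take]
                i += take
                if len(self._header) == 4:
                    self._expected = struct.unpack('>i', self._header)[0]
            else:
                take = min(self._expected - len(self._payload), n - i)
                self._payload += data[i:i + take]
                i += take
                if len(self._payload) == self._expected:
                    frames.append(self._payload)
                    self._header, self._payload, self._expected = b'', b'', None
        return frames

# Two frames delivered across awkward chunk boundaries:
raw = struct.pack('>i', 3) + b'abc' + struct.pack('>i', 2) + b'xy'
asm = FrameAssembler()
out = []
for chunk in (raw[:2], raw[2:9], raw[9:]):
    out.extend(asm.feed(chunk))
assert out == [b'abc', b'xy']

In the real method each completed payload is handed to _process_response, which now decodes the correlation id first and fails the connection with a CorrelationIdError if no request is in flight to match it.
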
kafka/protocol/message.py

Lines changed: 4 additions & 3 deletions
@@ -6,6 +6,7 @@
 from ..codec import (has_gzip, has_snappy, has_lz4,
                      gzip_decode, snappy_decode,
                      lz4_decode, lz4_decode_old_kafka)
+from .frame import KafkaBytes
 from .struct import Struct
 from .types import (
     Int8, Int32, Int64, Bytes, Schema, AbstractType
@@ -155,10 +156,10 @@ class MessageSet(AbstractType):
     @classmethod
     def encode(cls, items):
         # RecordAccumulator encodes messagesets internally
-        if isinstance(items, io.BytesIO):
+        if isinstance(items, (io.BytesIO, KafkaBytes)):
             size = Int32.decode(items)
             # rewind and return all the bytes
-            items.seek(-4, 1)
+            items.seek(items.tell() - 4)
             return items.read(size + 4)
 
         encoded_values = []
@@ -198,7 +199,7 @@ def decode(cls, data, bytes_to_read=None):
 
     @classmethod
     def repr(cls, messages):
-        if isinstance(messages, io.BytesIO):
+        if isinstance(messages, (KafkaBytes, io.BytesIO)):
             offset = messages.tell()
             decoded = cls.decode(messages)
             messages.seek(offset)

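The seek change in MessageSet.encode swaps a relative seek for an absolute one: items.seek(items.tell() - 4) lands on the same position as items.seek(-4, 1) for io.BytesIO, and it also works for buffer types whose seek() only takes an absolute position (presumably the case for KafkaBytes, whose implementation is not shown in this diff). A quick equivalence check on a plain BytesIO:

# Equivalence of the two rewind styles on a size-prefixed message set.
import io
import struct

buf = io.BytesIO(struct.pack('>i', 5) + b'hello')   # 4-byte size prefix + payload
size = struct.unpack('>i', buf.read(4))[0]
buf.seek(buf.tell() - 4)          # same position as buf.seek(-4, 1)
assert buf.read(size + 4) == struct.pack('>i', 5) + b'hello'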