Skip to content

Use SubscriptionType to track topics/pattern/user assignment #2565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions kafka/consumer/subscription_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from collections import Sequence
except ImportError:
from collections.abc import Sequence
try:
# enum in stdlib as of py3.4
from enum import IntEnum # pylint: disable=import-error
except ImportError:
# vendored backport module
from kafka.vendor.enum34 import IntEnum
import logging
import random
import re
Expand All @@ -20,6 +26,13 @@
log = logging.getLogger(__name__)


# Enumerates the mode a SubscriptionState is in, so that topic subscription,
# pattern subscription and manual partition assignment can be told apart.
class SubscriptionType(IntEnum):
# Nothing subscribed or assigned: the value set in __init__ and restored by
# unsubscribe().
NONE = 0
# Recorded by subscribe() when an explicit list of topics is supplied.
AUTO_TOPICS = 1
# Recorded by subscribe() when a regex pattern is supplied.
AUTO_PATTERN = 2
# Recorded by assign_from_user() when partitions are assigned manually.
USER_ASSIGNED = 3


class SubscriptionState(object):
"""
A class for tracking the topics, partitions, and offsets for the consumer.
Expand Down Expand Up @@ -67,6 +80,7 @@ def __init__(self, offset_reset_strategy='earliest'):
self._default_offset_reset_strategy = offset_reset_strategy

self.subscription = None # set() or None
self.subscription_type = SubscriptionType.NONE
self.subscribed_pattern = None # regex str or None
self._group_subscription = set()
self._user_assignment = set()
Expand All @@ -76,6 +90,14 @@ def __init__(self, offset_reset_strategy='earliest'):
# initialize to true for the consumers to fetch offset upon starting up
self.needs_fetch_committed_offsets = True

def _set_subscription_type(self, subscription_type):
    """Record the subscription mode, refusing to switch between modes.

    The mode may be set freely while the current mode is
    ``SubscriptionType.NONE``. Once a mode has been recorded, only the
    same mode is accepted again; any other value is rejected.

    Arguments:
        subscription_type (SubscriptionType): the mode being requested.

    Raises:
        ValueError: if subscription_type is not a SubscriptionType member.
        IllegalStateError: if a different mode is already recorded.
    """
    if not isinstance(subscription_type, SubscriptionType):
        raise ValueError('SubscriptionType enum required')

    current = self.subscription_type
    if current == SubscriptionType.NONE:
        # First caller wins: remember the requested mode.
        self.subscription_type = subscription_type
        return
    if current != subscription_type:
        # A different mode is already active; mixing modes is not allowed.
        raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)

def subscribe(self, topics=(), pattern=None, listener=None):
"""Subscribe to a list of topics, or a topic regex pattern.

Expand Down Expand Up @@ -111,17 +133,19 @@ def subscribe(self, topics=(), pattern=None, listener=None):
guaranteed, however, that the partitions revoked/assigned
through this interface are from topics subscribed in this call.
"""
if self._user_assignment or (topics and pattern):
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
assert topics or pattern, 'Must provide topics or pattern'
if (topics and pattern):
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)

if pattern:
elif pattern:
self._set_subscription_type(SubscriptionType.AUTO_PATTERN)
log.info('Subscribing to pattern: /%s/', pattern)
self.subscription = set()
self.subscribed_pattern = re.compile(pattern)
else:
if isinstance(topics, str) or not isinstance(topics, Sequence):
raise TypeError('Topics must be a list (or non-str sequence)')
self._set_subscription_type(SubscriptionType.AUTO_TOPICS)
self.change_subscription(topics)

if listener and not isinstance(listener, ConsumerRebalanceListener):
Expand All @@ -141,7 +165,7 @@ def change_subscription(self, topics):
- a topic name is '.' or '..' or
- a topic name does not consist of ASCII-characters/'-'/'_'/'.'
"""
if self._user_assignment:
if not self.partitions_auto_assigned():
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)

if isinstance(topics, six.string_types):
Expand All @@ -168,13 +192,13 @@ def group_subscribe(self, topics):
Arguments:
topics (list of str): topics to add to the group subscription
"""
if self._user_assignment:
if not self.partitions_auto_assigned():
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
self._group_subscription.update(topics)

def reset_group_subscription(self):
"""Reset the group's subscription to only contain topics subscribed by this consumer."""
if self._user_assignment:
if not self.partitions_auto_assigned():
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
assert self.subscription is not None, 'Subscription required'
self._group_subscription.intersection_update(self.subscription)
Expand All @@ -197,9 +221,7 @@ def assign_from_user(self, partitions):
Raises:
IllegalStateError: if consumer has already called subscribe()
"""
if self.subscription is not None:
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)

self._set_subscription_type(SubscriptionType.USER_ASSIGNED)
if self._user_assignment != set(partitions):
self._user_assignment = set(partitions)
self._set_assignment({partition: self.assignment.get(partition, TopicPartitionState())
Expand Down Expand Up @@ -250,6 +272,7 @@ def unsubscribe(self):
self._user_assignment.clear()
self.assignment.clear()
self.subscribed_pattern = None
self.subscription_type = SubscriptionType.NONE

def group_subscription(self):
"""Get the topic subscription for the group.
Expand Down Expand Up @@ -300,7 +323,7 @@ def fetchable_partitions(self):

def partitions_auto_assigned(self):
"""Return True unless user supplied partitions manually."""
return self.subscription is not None
return self.subscription_type in (SubscriptionType.AUTO_TOPICS, SubscriptionType.AUTO_PATTERN)

def all_consumed_offsets(self):
"""Returns consumed offsets as {TopicPartition: OffsetAndMetadata}"""
Expand Down
4 changes: 2 additions & 2 deletions test/test_consumer_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def test_kafka_consumer_unsupported_encoding(
def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages):
TIMEOUT_MS = 500
consumer = kafka_consumer_factory(auto_offset_reset='earliest',
enable_auto_commit=False,
consumer_timeout_ms=TIMEOUT_MS)
enable_auto_commit=False,
consumer_timeout_ms=TIMEOUT_MS)

# Manual assignment avoids overhead of consumer group mgmt
consumer.unsubscribe()
Expand Down
1 change: 1 addition & 0 deletions test/test_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def test_subscription_listener_failure(mocker, coordinator):


def test_perform_assignment(mocker, coordinator):
coordinator._subscription.subscribe(topics=['foo1'])
member_metadata = {
'member-foo': ConsumerProtocolMemberMetadata(0, ['foo1'], b''),
'member-bar': ConsumerProtocolMemberMetadata(0, ['foo1'], b'')
Expand Down
Loading