Skip to content

KIP-98: Add offsets support to transactional KafkaProducer #2590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions kafka/producer/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,36 @@ def begin_transaction(self):
raise Errors.IllegalStateError("Cannot use transactional methods without enabling transactions")
self._transaction_manager.begin_transaction()

def send_offsets_to_transaction(self, offsets, consumer_group_id):
    """
    Sends a list of consumed offsets to the consumer group coordinator, and also marks
    those offsets as part of the current transaction. These offsets will be considered
    consumed only if the transaction is committed successfully.

    This method should be used when you need to batch consumed and produced messages
    together, typically in a consume-transform-produce pattern.

    Arguments:
        offsets ({TopicPartition: OffsetAndMetadata}): map of topic-partition -> offsets to commit
            as part of current transaction.
        consumer_group_id (str): Name of consumer group for offsets commit.

    Raises:
        IllegalStateError: if no transactional_id, or transaction has not been started.
        ProducerFencedError: fatal error indicating another producer with the same transactional_id is active.
        UnsupportedVersionError: fatal error indicating the broker does not support transactions (i.e. if < 0.11).
        UnsupportedForMessageFormatError: fatal error indicating the message format used for the offsets
            topic on the broker does not support transactions.
        AuthorizationError: fatal error indicating that the configured transactional_id is not authorized.
        KafkaError: if the producer has encountered a previous fatal or abortable error, or for any
            other unexpected error
    """
    if not self._transaction_manager:
        raise Errors.IllegalStateError("Cannot use transactional methods without enabling transactions")
    # Enqueue the request with the transaction manager, wake the sender thread
    # so it is dispatched promptly, then block until the coordinator responds.
    result = self._transaction_manager.send_offsets_to_transaction(offsets, consumer_group_id)
    self._sender.wakeup()
    result.wait()

def commit_transaction(self):
""" Commits the ongoing transaction.

Expand Down
198 changes: 182 additions & 16 deletions kafka/producer/transaction_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
from kafka.vendor.enum34 import IntEnum

import kafka.errors as Errors
from kafka.protocol.add_offsets_to_txn import AddOffsetsToTxnRequest
from kafka.protocol.add_partitions_to_txn import AddPartitionsToTxnRequest
from kafka.protocol.end_txn import EndTxnRequest
from kafka.protocol.find_coordinator import FindCoordinatorRequest
from kafka.protocol.init_producer_id import InitProducerIdRequest
from kafka.protocol.txn_offset_commit import TxnOffsetCommitRequest
from kafka.structs import TopicPartition


Expand Down Expand Up @@ -115,6 +117,7 @@ def __init__(self, transactional_id=None, transaction_timeout_ms=0, retry_backof
self._new_partitions_in_transaction = set()
self._pending_partitions_in_transaction = set()
self._partitions_in_transaction = set()
self._pending_txn_offset_commits = dict()

self._current_state = TransactionState.UNINITIALIZED
self._last_error = None
Expand All @@ -138,7 +141,7 @@ def initialize_transactions(self):
self._transition_to(TransactionState.INITIALIZING)
self.set_producer_id_and_epoch(ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH))
self._sequence_numbers.clear()
handler = InitProducerIdHandler(self, self.transactional_id, self.transaction_timeout_ms)
handler = InitProducerIdHandler(self, self.transaction_timeout_ms)
self._enqueue_request(handler)
return handler.result

Expand Down Expand Up @@ -169,10 +172,22 @@ def begin_abort(self):
def _begin_completing_transaction(self, committed):
    """Flush any not-yet-added partitions, then enqueue an EndTxn request.

    Arguments:
        committed (bool): True to commit the transaction, False to abort it.

    Returns:
        TransactionalRequestResult: completes when the EndTxn response is handled.
    """
    # Removed the stale pre-change line that constructed EndTxnHandler with the
    # old (transactional_id, producer_id, epoch) signature; the handler now
    # reads those values from the transaction manager itself.
    if self._new_partitions_in_transaction:
        self._enqueue_request(self._add_partitions_to_transaction_handler())
    handler = EndTxnHandler(self, committed)
    self._enqueue_request(handler)
    return handler.result

def send_offsets_to_transaction(self, offsets, consumer_group_id):
    """Enqueue an AddOffsetsToTxn request marking consumed offsets as part of
    the current transaction.

    Arguments:
        offsets ({TopicPartition: OffsetAndMetadata}): offsets to commit with the transaction.
        consumer_group_id (str): consumer group on whose behalf offsets are committed.

    Returns:
        TransactionalRequestResult: completes after the follow-up TxnOffsetCommit succeeds.

    Raises:
        KafkaError: if there is no active transaction, or a prior fatal/abortable error.
    """
    with self._lock:
        self._ensure_transactional()
        self._maybe_fail_with_error()
        in_active_transaction = self._current_state == TransactionState.IN_TRANSACTION
        if not in_active_transaction:
            raise Errors.KafkaError("Cannot send offsets to transaction because the producer is not in an active transaction")

    log.debug("Begin adding offsets %s for consumer group %s to transaction", offsets, consumer_group_id)
    offsets_handler = AddOffsetsToTxnHandler(self, consumer_group_id, offsets)
    self._enqueue_request(offsets_handler)
    return offsets_handler.result

def maybe_add_partition_to_transaction(self, topic_partition):
with self._lock:
self._fail_if_not_ready_for_send()
Expand Down Expand Up @@ -389,6 +404,10 @@ def _test_transaction_contains_partition(self, tp):
with self._lock:
return tp in self._partitions_in_transaction

# visible for testing
def _test_has_pending_offset_commits(self):
    """Return True while any transactional offset commit is still pending."""
    return len(self._pending_txn_offset_commits) > 0

# visible for testing
def _test_has_ongoing_transaction(self):
with self._lock:
Expand Down Expand Up @@ -473,7 +492,7 @@ def _add_partitions_to_transaction_handler(self):
with self._lock:
self._pending_partitions_in_transaction.update(self._new_partitions_in_transaction)
self._new_partitions_in_transaction.clear()
return AddPartitionsToTxnHandler(self, self.transactional_id, self.producer_id_and_epoch.producer_id, self.producer_id_and_epoch.epoch, self._pending_partitions_in_transaction)
return AddPartitionsToTxnHandler(self, self._pending_partitions_in_transaction)


class TransactionalRequestResult(object):
Expand Down Expand Up @@ -518,6 +537,18 @@ def __init__(self, transaction_manager, result=None):
self._result = result or TransactionalRequestResult()
self._is_retry = False

@property
def transactional_id(self):
    """The transactional.id configured on the owning transaction manager."""
    manager = self.transaction_manager
    return manager.transactional_id

@property
def producer_id(self):
    """Producer id taken from the manager's current producer-id/epoch pair."""
    id_and_epoch = self.transaction_manager.producer_id_and_epoch
    return id_and_epoch.producer_id

@property
def producer_epoch(self):
    """Producer epoch taken from the manager's current producer-id/epoch pair."""
    id_and_epoch = self.transaction_manager.producer_id_and_epoch
    return id_and_epoch.epoch

def fatal_error(self, exc):
    # Move the transaction manager into its terminal fatal-error state and
    # complete this handler's pending result with the error so waiters unblock.
    self.transaction_manager._transition_to_fatal_error(exc)
    self._result.done(error=exc)
Expand Down Expand Up @@ -585,16 +616,15 @@ def priority(self):


class InitProducerIdHandler(TxnRequestHandler):
def __init__(self, transaction_manager, transaction_timeout_ms):
    """Build an InitProducerId request for this transactional producer.

    Arguments:
        transaction_manager (TransactionManager): owning manager; supplies the
            transactional_id via the base-class property.
        transaction_timeout_ms (int): max time the coordinator will wait for
            transaction status updates before aborting.
    """
    # Removed the stale pre-change diff lines (old signature taking
    # transactional_id, the redundant self.transactional_id assignment, and the
    # old kwarg); transactional_id now comes from the TxnRequestHandler property.
    super(InitProducerIdHandler, self).__init__(transaction_manager)

    # v1 is supported by brokers >= 2.0; fall back to v0 otherwise.
    if transaction_manager._api_version >= (2, 0):
        version = 1
    else:
        version = 0
    self.request = InitProducerIdRequest[version](
        transactional_id=self.transactional_id,
        transaction_timeout_ms=transaction_timeout_ms)

@property
Expand All @@ -619,10 +649,9 @@ def handle_response(self, response):
self.fatal_error(Errors.KafkaError("Unexpected error in InitProducerIdResponse: %s" % (error())))

class AddPartitionsToTxnHandler(TxnRequestHandler):
def __init__(self, transaction_manager, transactional_id, producer_id, producer_epoch, topic_partitions):
def __init__(self, transaction_manager, topic_partitions):
super(AddPartitionsToTxnHandler, self).__init__(transaction_manager)

self.transactional_id = transactional_id
if transaction_manager._api_version >= (2, 7):
version = 2
elif transaction_manager._api_version >= (2, 0):
Expand All @@ -633,9 +662,9 @@ def __init__(self, transaction_manager, transactional_id, producer_id, producer_
for tp in topic_partitions:
topic_data[tp.topic].append(tp.partition)
self.request = AddPartitionsToTxnRequest[version](
transactional_id=transactional_id,
producer_id=producer_id,
producer_epoch=producer_epoch,
transactional_id=self.transactional_id,
producer_id=self.producer_id,
producer_epoch=self.producer_epoch,
topics=list(topic_data.items()))

@property
Expand Down Expand Up @@ -771,20 +800,19 @@ def handle_response(self, response):


class EndTxnHandler(TxnRequestHandler):
def __init__(self, transaction_manager, committed):
    """Build an EndTxn request committing or aborting the current transaction.

    Arguments:
        transaction_manager (TransactionManager): owning manager; supplies
            transactional_id, producer_id and producer_epoch via base-class properties.
        committed (bool): True to commit, False to abort.
    """
    # Removed the stale pre-change diff lines (old signature and old kwargs that
    # passed transactional_id/producer_id/epoch explicitly); those values are
    # now read live from the transaction manager via TxnRequestHandler properties.
    super(EndTxnHandler, self).__init__(transaction_manager)

    # Pick the newest request version the broker supports.
    if self.transaction_manager._api_version >= (2, 7):
        version = 2
    elif self.transaction_manager._api_version >= (2, 0):
        version = 1
    else:
        version = 0
    self.request = EndTxnRequest[version](
        transactional_id=self.transactional_id,
        producer_id=self.producer_id,
        producer_epoch=self.producer_epoch,
        committed=committed)

@property
Expand All @@ -810,3 +838,141 @@ def handle_response(self, response):
self.fatal_error(error())
else:
self.fatal_error(Errors.KafkaError("Unhandled error in EndTxnResponse: %s" % (error())))


class AddOffsetsToTxnHandler(TxnRequestHandler):
    """Handler for AddOffsetsToTxn requests sent to the transaction coordinator.

    On success it chains a TxnOffsetCommitHandler (sharing this handler's
    result), so the result completes only once the offsets themselves have been
    committed to the group coordinator.
    """

    def __init__(self, transaction_manager, consumer_group_id, offsets):
        """
        Arguments:
            transaction_manager (TransactionManager): owning manager.
            consumer_group_id (str): group on whose behalf offsets are committed.
            offsets ({TopicPartition: OffsetAndMetadata}): offsets to commit
                as part of the current transaction.
        """
        super(AddOffsetsToTxnHandler, self).__init__(transaction_manager)

        self.consumer_group_id = consumer_group_id
        self.offsets = offsets
        # Pick the newest request version the broker supports.
        if self.transaction_manager._api_version >= (2, 7):
            version = 2
        elif self.transaction_manager._api_version >= (2, 0):
            version = 1
        else:
            version = 0
        self.request = AddOffsetsToTxnRequest[version](
            transactional_id=self.transactional_id,
            producer_id=self.producer_id,
            producer_epoch=self.producer_epoch,
            group_id=consumer_group_id)

    @property
    def priority(self):
        # Same priority bucket as AddPartitionsToTxn requests.
        return Priority.ADD_PARTITIONS_OR_OFFSETS

    def handle_response(self, response):
        """Dispatch on the response error code: chain the TxnOffsetCommit on
        success, retry on transient coordinator errors, otherwise fail."""
        error = Errors.for_code(response.error_code)

        if error is Errors.NoError:
            log.debug("Successfully added partition for consumer group %s to transaction", self.consumer_group_id)

            # note the result is not completed until the TxnOffsetCommit returns
            for tp, offset in six.iteritems(self.offsets):
                self.transaction_manager._pending_txn_offset_commits[tp] = offset
            handler = TxnOffsetCommitHandler(self.transaction_manager, self.consumer_group_id,
                                             self.transaction_manager._pending_txn_offset_commits, self._result)
            self.transaction_manager._enqueue_request(handler)
            self.transaction_manager._transaction_started = True
        elif error in (Errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError):
            # Coordinator moved: re-discover it, then retry this request.
            self.transaction_manager._lookup_coordinator('transaction', self.transactional_id)
            self.reenqueue()
        elif error in (Errors.CoordinatorLoadInProgressError, Errors.ConcurrentTransactionsError):
            # Transient broker-side condition; retry without re-discovery.
            self.reenqueue()
        elif error is Errors.InvalidProducerEpochError:
            # Another producer with the same transactional_id fenced us: fatal.
            self.fatal_error(error())
        elif error is Errors.TransactionalIdAuthorizationFailedError:
            self.fatal_error(error())
        elif error is Errors.GroupAuthorizationFailedError:
            # The transaction can still be aborted, so this is abortable, not fatal.
            self.abortable_error(Errors.GroupAuthorizationError(self.consumer_group_id))
        else:
            self.fatal_error(Errors.KafkaError("Unexpected error in AddOffsetsToTxnResponse: %s" % (error())))


class TxnOffsetCommitHandler(TxnRequestHandler):
    """Handler for TxnOffsetCommit requests sent to the *group* coordinator.

    Created by AddOffsetsToTxnHandler and shares its result object, so the
    caller's result completes only once all offsets are committed (or a
    non-retriable failure occurs).
    """

    def __init__(self, transaction_manager, consumer_group_id, offsets, result):
        """
        Arguments:
            transaction_manager (TransactionManager): owning manager.
            consumer_group_id (str): group on whose behalf offsets are committed.
            offsets ({TopicPartition: OffsetAndMetadata}): pending offsets to commit.
            result (TransactionalRequestResult): shared result from the
                AddOffsetsToTxn step; completed by this handler.
        """
        super(TxnOffsetCommitHandler, self).__init__(transaction_manager, result=result)

        self.consumer_group_id = consumer_group_id
        self.offsets = offsets
        self.request = self._build_request()

    def _build_request(self):
        """Build a TxnOffsetCommitRequest covering the pending offsets."""
        if self.transaction_manager._api_version >= (2, 1):
            version = 2
        elif self.transaction_manager._api_version >= (2, 0):
            version = 1
        else:
            version = 0

        topic_data = collections.defaultdict(list)
        for tp, offset in six.iteritems(self.offsets):
            if version >= 2:
                # v2 adds leader_epoch to the partition entry.
                partition_data = (tp.partition, offset.offset, offset.leader_epoch, offset.metadata)
            else:
                partition_data = (tp.partition, offset.offset, offset.metadata)
            topic_data[tp.topic].append(partition_data)

        return TxnOffsetCommitRequest[version](
            transactional_id=self.transactional_id,
            group_id=self.consumer_group_id,
            producer_id=self.producer_id,
            producer_epoch=self.producer_epoch,
            topics=list(topic_data.items()))

    @property
    def priority(self):
        return Priority.ADD_PARTITIONS_OR_OFFSETS

    @property
    def coordinator_type(self):
        # Offsets are committed to the group coordinator, not the
        # transaction coordinator.
        return 'group'

    @property
    def coordinator_key(self):
        return self.consumer_group_id

    def handle_response(self, response):
        """Process per-partition results: clear successful commits, retry
        retriable failures with a rebuilt request, fail otherwise."""
        lookup_coordinator = False
        retriable_failure = False

        errors = {TopicPartition(topic, partition): Errors.for_code(error_code)
                  for topic, partition_data in response.topics
                  for partition, error_code in partition_data}

        for tp, error in six.iteritems(errors):
            if error is Errors.NoError:
                log.debug("Successfully added offsets for %s from consumer group %s to transaction.",
                          tp, self.consumer_group_id)
                del self.transaction_manager._pending_txn_offset_commits[tp]
            # BUGFIX: was `errors.CoordinatorNotAvailableError` -- an attribute
            # lookup on the local `errors` dict, which raises AttributeError at
            # runtime; the error classes live on the Errors module.
            elif error in (Errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError, Errors.RequestTimedOutError):
                retriable_failure = True
                lookup_coordinator = True
            elif error is Errors.UnknownTopicOrPartitionError:
                retriable_failure = True
            elif error is Errors.GroupAuthorizationFailedError:
                self.abortable_error(Errors.GroupAuthorizationError(self.consumer_group_id))
                return
            elif error in (Errors.TransactionalIdAuthorizationFailedError,
                           Errors.InvalidProducerEpochError,
                           Errors.UnsupportedForMessageFormatError):
                self.fatal_error(error())
                return
            else:
                self.fatal_error(Errors.KafkaError("Unexpected error in TxnOffsetCommitResponse: %s" % (error())))
                return

        if lookup_coordinator:
            self.transaction_manager._lookup_coordinator('group', self.consumer_group_id)

        if not retriable_failure:
            # all attempted partitions were either successful, or there was a fatal failure.
            # either way, we are not retrying, so complete the request.
            self.result.done()

        # retry the commits which failed with a retriable error.
        elif self.transaction_manager._pending_txn_offset_commits:
            self.offsets = self.transaction_manager._pending_txn_offset_commits
            self.request = self._build_request()
            self.reenqueue()
Loading
Loading