Skip to content

Commit 4cad6f1

Browse files
committed
test_sender updates for transaction manager
1 parent 5b205c3 commit 4cad6f1

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

test/test_sender.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
from kafka.vendor import six
1212

1313
from kafka.client_async import KafkaClient
14+
from kafka.cluster import ClusterMetadata
1415
import kafka.errors as Errors
1516
from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
1617
from kafka.producer.kafka import KafkaProducer
1718
from kafka.protocol.produce import ProduceRequest
1819
from kafka.producer.record_accumulator import RecordAccumulator, ProducerBatch
1920
from kafka.producer.sender import Sender
20-
from kafka.producer.transaction_state import TransactionState
21+
from kafka.producer.transaction_manager import TransactionManager
2122
from kafka.record.memory_records import MemoryRecordsBuilder
2223
from kafka.structs import TopicPartition
2324

@@ -42,6 +43,16 @@ def producer_batch(topic='foo', partition=0, magic=2):
4243
return batch
4344

4445

46+
@pytest.fixture
47+
def transaction_manager():
48+
return TransactionManager(
49+
transactional_id=None,
50+
transaction_timeout_ms=60000,
51+
retry_backoff_ms=100,
52+
api_version=(2, 1),
53+
metadata=ClusterMetadata())
54+
55+
4556
@pytest.mark.parametrize(("api_version", "produce_version"), [
4657
((2, 1), 7),
4758
((0, 10, 0), 2),
@@ -85,16 +96,16 @@ def test_complete_batch_success(sender):
8596
assert batch.produce_future.value == (0, 123, 456)
8697

8798

88-
def test_complete_batch_transaction(sender):
89-
sender._transaction_state = TransactionState()
99+
def test_complete_batch_transaction(sender, transaction_manager):
100+
sender._transaction_manager = transaction_manager
90101
batch = producer_batch()
91-
assert sender._transaction_state.sequence_number(batch.topic_partition) == 0
92-
assert sender._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id
102+
assert sender._transaction_manager.sequence_number(batch.topic_partition) == 0
103+
assert sender._transaction_manager.producer_id_and_epoch.producer_id == batch.producer_id
93104

94105
# No error, base_offset 0
95106
sender._complete_batch(batch, None, 0)
96107
assert batch.is_done
97-
assert sender._transaction_state.sequence_number(batch.topic_partition) == batch.record_count
108+
assert sender._transaction_manager.sequence_number(batch.topic_partition) == batch.record_count
98109

99110

100111
@pytest.mark.parametrize(("error", "refresh_metadata"), [
@@ -164,8 +175,8 @@ def test_complete_batch_retry(sender, accumulator, mocker, error, retry):
164175
assert isinstance(batch.produce_future.exception, error)
165176

166177

167-
def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, mocker):
168-
sender._transaction_state = TransactionState()
178+
def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, transaction_manager, mocker):
179+
sender._transaction_manager = transaction_manager
169180
sender.config['retries'] = 1
170181
mocker.spy(sender, '_fail_batch')
171182
mocker.patch.object(accumulator, 'reenqueue')
@@ -175,21 +186,21 @@ def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, mocker
175186
assert not batch.is_done
176187
accumulator.reenqueue.assert_called_with(batch)
177188
batch.records._producer_id = 123 # simulate different producer_id
178-
assert batch.producer_id != sender._transaction_state.producer_id_and_epoch.producer_id
189+
assert batch.producer_id != sender._transaction_manager.producer_id_and_epoch.producer_id
179190
sender._complete_batch(batch, error, -1)
180191
assert batch.is_done
181192
assert isinstance(batch.produce_future.exception, error)
182193

183194

184-
def test_fail_batch(sender, accumulator, mocker):
185-
sender._transaction_state = TransactionState()
186-
mocker.patch.object(TransactionState, 'reset_producer_id')
195+
def test_fail_batch(sender, accumulator, transaction_manager, mocker):
196+
sender._transaction_manager = transaction_manager
197+
mocker.patch.object(TransactionManager, 'reset_producer_id')
187198
batch = producer_batch()
188199
mocker.patch.object(batch, 'done')
189-
assert sender._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id
200+
assert sender._transaction_manager.producer_id_and_epoch.producer_id == batch.producer_id
190201
error = Exception('error')
191202
sender._fail_batch(batch, base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None)
192-
sender._transaction_state.reset_producer_id.assert_called_once()
203+
sender._transaction_manager.reset_producer_id.assert_called_once()
193204
batch.done.assert_called_with(base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None)
194205

195206

0 commit comments

Comments
 (0)