11
11
from kafka .vendor import six
12
12
13
13
from kafka .client_async import KafkaClient
14
+ from kafka .cluster import ClusterMetadata
14
15
import kafka .errors as Errors
15
16
from kafka .protocol .broker_api_versions import BROKER_API_VERSIONS
16
17
from kafka .producer .kafka import KafkaProducer
17
18
from kafka .protocol .produce import ProduceRequest
18
19
from kafka .producer .record_accumulator import RecordAccumulator , ProducerBatch
19
20
from kafka .producer .sender import Sender
20
- from kafka .producer .transaction_state import TransactionState
21
+ from kafka .producer .transaction_manager import TransactionManager
21
22
from kafka .record .memory_records import MemoryRecordsBuilder
22
23
from kafka .structs import TopicPartition
23
24
@@ -42,6 +43,16 @@ def producer_batch(topic='foo', partition=0, magic=2):
42
43
return batch
43
44
44
45
46
+ @pytest .fixture
47
+ def transaction_manager ():
48
+ return TransactionManager (
49
+ transactional_id = None ,
50
+ transaction_timeout_ms = 60000 ,
51
+ retry_backoff_ms = 100 ,
52
+ api_version = (2 , 1 ),
53
+ metadata = ClusterMetadata ())
54
+
55
+
45
56
@pytest .mark .parametrize (("api_version" , "produce_version" ), [
46
57
((2 , 1 ), 7 ),
47
58
((0 , 10 , 0 ), 2 ),
@@ -85,16 +96,16 @@ def test_complete_batch_success(sender):
85
96
assert batch .produce_future .value == (0 , 123 , 456 )
86
97
87
98
88
- def test_complete_batch_transaction (sender ):
89
- sender ._transaction_state = TransactionState ()
99
+ def test_complete_batch_transaction (sender , transaction_manager ):
100
+ sender ._transaction_manager = transaction_manager
90
101
batch = producer_batch ()
91
- assert sender ._transaction_state .sequence_number (batch .topic_partition ) == 0
92
- assert sender ._transaction_state .producer_id_and_epoch .producer_id == batch .producer_id
102
+ assert sender ._transaction_manager .sequence_number (batch .topic_partition ) == 0
103
+ assert sender ._transaction_manager .producer_id_and_epoch .producer_id == batch .producer_id
93
104
94
105
# No error, base_offset 0
95
106
sender ._complete_batch (batch , None , 0 )
96
107
assert batch .is_done
97
- assert sender ._transaction_state .sequence_number (batch .topic_partition ) == batch .record_count
108
+ assert sender ._transaction_manager .sequence_number (batch .topic_partition ) == batch .record_count
98
109
99
110
100
111
@pytest .mark .parametrize (("error" , "refresh_metadata" ), [
@@ -164,8 +175,8 @@ def test_complete_batch_retry(sender, accumulator, mocker, error, retry):
164
175
assert isinstance (batch .produce_future .exception , error )
165
176
166
177
167
- def test_complete_batch_producer_id_changed_no_retry (sender , accumulator , mocker ):
168
- sender ._transaction_state = TransactionState ()
178
+ def test_complete_batch_producer_id_changed_no_retry (sender , accumulator , transaction_manager , mocker ):
179
+ sender ._transaction_manager = transaction_manager
169
180
sender .config ['retries' ] = 1
170
181
mocker .spy (sender , '_fail_batch' )
171
182
mocker .patch .object (accumulator , 'reenqueue' )
@@ -175,21 +186,21 @@ def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, mocker
175
186
assert not batch .is_done
176
187
accumulator .reenqueue .assert_called_with (batch )
177
188
batch .records ._producer_id = 123 # simulate different producer_id
178
- assert batch .producer_id != sender ._transaction_state .producer_id_and_epoch .producer_id
189
+ assert batch .producer_id != sender ._transaction_manager .producer_id_and_epoch .producer_id
179
190
sender ._complete_batch (batch , error , - 1 )
180
191
assert batch .is_done
181
192
assert isinstance (batch .produce_future .exception , error )
182
193
183
194
184
- def test_fail_batch (sender , accumulator , mocker ):
185
- sender ._transaction_state = TransactionState ()
186
- mocker .patch .object (TransactionState , 'reset_producer_id' )
195
+ def test_fail_batch (sender , accumulator , transaction_manager , mocker ):
196
+ sender ._transaction_manager = transaction_manager
197
+ mocker .patch .object (TransactionManager , 'reset_producer_id' )
187
198
batch = producer_batch ()
188
199
mocker .patch .object (batch , 'done' )
189
- assert sender ._transaction_state .producer_id_and_epoch .producer_id == batch .producer_id
200
+ assert sender ._transaction_manager .producer_id_and_epoch .producer_id == batch .producer_id
190
201
error = Exception ('error' )
191
202
sender ._fail_batch (batch , base_offset = 0 , timestamp_ms = None , exception = error , log_start_offset = None )
192
- sender ._transaction_state .reset_producer_id .assert_called_once ()
203
+ sender ._transaction_manager .reset_producer_id .assert_called_once ()
193
204
batch .done .assert_called_with (base_offset = 0 , timestamp_ms = None , exception = error , log_start_offset = None )
194
205
195
206
0 commit comments