@@ -287,21 +287,26 @@ def test_simple_producer(self):
287
287
producer = SimpleProducer (self .client )
288
288
resp = producer .send_messages (self .topic , "one" , "two" )
289
289
290
- # Will go to partition 0
290
+ partition_for_first_batch = resp [0 ].partition
291
+
291
292
self .assertEquals (len (resp ), 1 )
292
293
self .assertEquals (resp [0 ].error , 0 )
293
294
self .assertEquals (resp [0 ].offset , 0 ) # offset of first msg
294
295
295
- # Will go to partition 1
296
+ # ensure this partition is different from the first partition
296
297
resp = producer .send_messages (self .topic , "three" )
298
+ partition_for_second_batch = resp [0 ].partition
299
+ self .assertNotEquals (partition_for_first_batch , partition_for_second_batch )
300
+
297
301
self .assertEquals (len (resp ), 1 )
298
302
self .assertEquals (resp [0 ].error , 0 )
299
303
self .assertEquals (resp [0 ].offset , 0 ) # offset of first msg
300
304
301
- fetch1 = FetchRequest (self .topic , 0 , 0 , 1024 )
302
- fetch2 = FetchRequest (self .topic , 1 , 0 , 1024 )
303
- fetch_resp1 , fetch_resp2 = self .client .send_fetch_request ([fetch1 ,
304
- fetch2 ])
305
+ fetch_requests = (
306
+ FetchRequest (self .topic , partition_for_first_batch , 0 , 1024 ),
307
+ FetchRequest (self .topic , partition_for_second_batch , 0 , 1024 ),
308
+ )
309
+ fetch_resp1 , fetch_resp2 = self .client .send_fetch_request (fetch_requests )
305
310
self .assertEquals (fetch_resp1 .error , 0 )
306
311
self .assertEquals (fetch_resp1 .highwaterMark , 2 )
307
312
messages = list (fetch_resp1 .messages )
@@ -314,11 +319,12 @@ def test_simple_producer(self):
314
319
self .assertEquals (len (messages ), 1 )
315
320
self .assertEquals (messages [0 ].message .value , "three" )
316
321
317
- # Will go to partition 0
322
+ # Will go to same partition as first batch
318
323
resp = producer .send_messages (self .topic , "four" , "five" )
319
324
self .assertEquals (len (resp ), 1 )
320
325
self .assertEquals (resp [0 ].error , 0 )
321
326
self .assertEquals (resp [0 ].offset , 2 ) # offset of first msg
327
+ self .assertEquals (resp [0 ].partition , partition_for_first_batch )
322
328
323
329
producer .stop ()
324
330
@@ -396,14 +402,25 @@ def test_acks_none(self):
396
402
resp = producer .send_messages (self .topic , "one" )
397
403
self .assertEquals (len (resp ), 0 )
398
404
399
- fetch = FetchRequest (self .topic , 0 , 0 , 1024 )
400
- fetch_resp = self .client .send_fetch_request ([fetch ])
405
+ # fetch from both partitions
406
+ fetch_requests = (
407
+ FetchRequest (self .topic , 0 , 0 , 1024 ),
408
+ FetchRequest (self .topic , 1 , 0 , 1024 ),
409
+ )
410
+ fetch_resps = self .client .send_fetch_request (fetch_requests )
401
411
402
- self .assertEquals (fetch_resp [0 ].error , 0 )
403
- self .assertEquals (fetch_resp [0 ].highwaterMark , 1 )
404
- self .assertEquals (fetch_resp [0 ].partition , 0 )
412
+ # determine which partition was selected (due to random round-robin)
413
+ published_to_resp = max (fetch_resps , key = lambda x : x .highwaterMark )
414
+ not_published_to_resp = min (fetch_resps , key = lambda x : x .highwaterMark )
415
+ self .assertNotEquals (published_to_resp .partition , not_published_to_resp .partition )
405
416
406
- messages = list (fetch_resp [0 ].messages )
417
+ self .assertEquals (published_to_resp .error , 0 )
418
+ self .assertEquals (published_to_resp .highwaterMark , 1 )
419
+
420
+ self .assertEquals (not_published_to_resp .error , 0 )
421
+ self .assertEquals (not_published_to_resp .highwaterMark , 0 )
422
+
423
+ messages = list (published_to_resp .messages )
407
424
self .assertEquals (len (messages ), 1 )
408
425
self .assertEquals (messages [0 ].message .value , "one" )
409
426
@@ -415,12 +432,14 @@ def test_acks_local_write(self):
415
432
resp = producer .send_messages (self .topic , "one" )
416
433
self .assertEquals (len (resp ), 1 )
417
434
418
- fetch = FetchRequest (self .topic , 0 , 0 , 1024 )
435
+ partition = resp [0 ].partition
436
+
437
+ fetch = FetchRequest (self .topic , partition , 0 , 1024 )
419
438
fetch_resp = self .client .send_fetch_request ([fetch ])
420
439
421
440
self .assertEquals (fetch_resp [0 ].error , 0 )
422
441
self .assertEquals (fetch_resp [0 ].highwaterMark , 1 )
423
- self .assertEquals (fetch_resp [0 ].partition , 0 )
442
+ self .assertEquals (fetch_resp [0 ].partition , partition )
424
443
425
444
messages = list (fetch_resp [0 ].messages )
426
445
self .assertEquals (len (messages ), 1 )
@@ -435,12 +454,14 @@ def test_acks_cluster_commit(self):
435
454
resp = producer .send_messages (self .topic , "one" )
436
455
self .assertEquals (len (resp ), 1 )
437
456
438
- fetch = FetchRequest (self .topic , 0 , 0 , 1024 )
457
+ partition = resp [0 ].partition
458
+
459
+ fetch = FetchRequest (self .topic , partition , 0 , 1024 )
439
460
fetch_resp = self .client .send_fetch_request ([fetch ])
440
461
441
462
self .assertEquals (fetch_resp [0 ].error , 0 )
442
463
self .assertEquals (fetch_resp [0 ].highwaterMark , 1 )
443
- self .assertEquals (fetch_resp [0 ].partition , 0 )
464
+ self .assertEquals (fetch_resp [0 ].partition , partition )
444
465
445
466
messages = list (fetch_resp [0 ].messages )
446
467
self .assertEquals (len (messages ), 1 )
@@ -456,17 +477,31 @@ def test_async_simple_producer(self):
456
477
# Give it some time
457
478
time .sleep (2 )
458
479
459
- fetch = FetchRequest (self .topic , 0 , 0 , 1024 )
460
- fetch_resp = self .client .send_fetch_request ([fetch ])
480
+ # fetch from both partitions
481
+ fetch_requests = (
482
+ FetchRequest (self .topic , 0 , 0 , 1024 ),
483
+ FetchRequest (self .topic , 1 , 0 , 1024 ),
484
+ )
485
+ fetch_resps = self .client .send_fetch_request (fetch_requests )
461
486
462
- self .assertEquals (fetch_resp [0 ].error , 0 )
463
- self .assertEquals (fetch_resp [0 ].highwaterMark , 1 )
464
- self .assertEquals (fetch_resp [0 ].partition , 0 )
487
+ # determine which partition was selected (due to random round-robin)
488
+ published_to_resp = max (fetch_resps , key = lambda x : x .highwaterMark )
489
+ not_published_to_resp = min (fetch_resps , key = lambda x : x .highwaterMark )
490
+ self .assertNotEquals (published_to_resp .partition , not_published_to_resp .partition )
465
491
466
- messages = list (fetch_resp [0 ].messages )
492
+ self .assertEquals (published_to_resp .error , 0 )
493
+ self .assertEquals (published_to_resp .highwaterMark , 1 )
494
+
495
+ self .assertEquals (not_published_to_resp .error , 0 )
496
+ self .assertEquals (not_published_to_resp .highwaterMark , 0 )
497
+
498
+ messages = list (published_to_resp .messages )
467
499
self .assertEquals (len (messages ), 1 )
468
500
self .assertEquals (messages [0 ].message .value , "one" )
469
501
502
+ messages = list (not_published_to_resp .messages )
503
+ self .assertEquals (len (messages ), 0 )
504
+
470
505
producer .stop ()
471
506
472
507
def test_async_keyed_producer (self ):
0 commit comments