Skip to content

Commit a58f40f

Browse files
petergtzsvekars
andauthored
Fix batch size calculation in dist_tuto (#2754)
Batch size must be an int, not a float. This change fixes it, basically doing the same as in https://github.com/seba-1511/dist_tuto.pth/blob/a552567061a9985cdcfe72ecb9b47e4630d6a7fe/train_dist.py#L85. Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent cdbb559 commit a58f40f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

intermediate_source/dist_tuto.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ the following few lines:
327327
transforms.Normalize((0.1307,), (0.3081,))
328328
]))
329329
size = dist.get_world_size()
330-
bsz = 128 / float(size)
330+
bsz = 128 // size
331331
partition_sizes = [1.0 / size for _ in range(size)]
332332
partition = DataPartitioner(dataset, partition_sizes)
333333
partition = partition.use(dist.get_rank())

0 commit comments

Comments
 (0)