Skip to content

Commit

Permalink
Fix batch size calculation in dist_tuto (#2754)
Browse files Browse the repository at this point in the history
Batch size must be an int, not a float. This change fixes it, basically doing the same as in https://github.com/seba-1511/dist_tuto.pth/blob/a552567061a9985cdcfe72ecb9b47e4630d6a7fe/train_dist.py#L85.

Co-authored-by: Svetlana Karslioglu <[email protected]>
  • Loading branch information
petergtz and svekars authored May 31, 2024
1 parent cdbb559 commit a58f40f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion intermediate_source/dist_tuto.rst
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ the following few lines:
transforms.Normalize((0.1307,), (0.3081,))
]))
size = dist.get_world_size()
bsz = 128 / float(size)
bsz = 128 // size
partition_sizes = [1.0 / size for _ in range(size)]
partition = DataPartitioner(dataset, partition_sizes)
partition = partition.use(dist.get_rank())
Expand Down

0 comments on commit a58f40f

Please sign in to comment.