Skip to content

Commit

Permalink
Update tests/test_utils.py
Browse files Browse the repository at this point in the history
Co-authored-by: Felix Stollenwerk <[email protected]>
  • Loading branch information
mali-git and flxst authored Aug 13, 2024
1 parent 7bf0a03 commit ee2dd2a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_get_total_number_of_trainable_parameters():
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))

# Calculate the expected number of trainable parameters
expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
expected_params = 10*5 + 5 + 5*2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67
world_size = 8
num_gpus_per_node = 4

Expand Down

0 comments on commit ee2dd2a

Please sign in to comment.