Skip to content

Commit ee2dd2a

Browse files
mali-gitflxst
andauthored
Update tests/test_utils.py
Co-authored-by: Felix Stollenwerk <[email protected]>
1 parent 7bf0a03 commit ee2dd2a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_get_total_number_of_trainable_parameters():
4646
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
4747

4848
# Calculate the expected number of trainable parameters
49-
expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
49+
expected_params = 10*5 + 5 + 5*2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67
5050
world_size = 8
5151
num_gpus_per_node = 4
5252

0 commit comments

Comments
 (0)