From ee2dd2a15a03c41ea2754d05d037089c4f9c1a5e Mon Sep 17 00:00:00 2001 From: Mehdi Ali <33023925+mali-git@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:20:36 +0200 Subject: [PATCH] Update tests/test_utils.py Co-authored-by: Felix Stollenwerk --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index bbeea477..caa0fd01 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -46,7 +46,7 @@ def test_get_total_number_of_trainable_parameters(): model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2)) # Calculate the expected number of trainable parameters - expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + expected_params = 10*5 + 5 + 5*2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67 world_size = 8 num_gpus_per_node = 4