From b3f76c79968c5b162cb5521f7e7dfa0a904f042c Mon Sep 17 00:00:00 2001 From: Felix Stollenwerk Date: Tue, 13 Aug 2024 20:15:23 +0200 Subject: [PATCH 1/2] fix: linting in test_utils --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index caa0fd01..fa4b657b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -46,7 +46,7 @@ def test_get_total_number_of_trainable_parameters(): model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2)) # Calculate the expected number of trainable parameters - expected_params = 10*5 + 5 + 5*2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67 + expected_params = 10 * 5 + 5 + 5 * 2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67 world_size = 8 num_gpus_per_node = 4 From b13aeebe34cf7ffab187213feac651d6bf790cd4 Mon Sep 17 00:00:00 2001 From: Felix Stollenwerk Date: Tue, 13 Aug 2024 20:18:58 +0200 Subject: [PATCH 2/2] chore: manual computation of expected number of parameters in test_utils --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index fa4b657b..2850f257 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -35,7 +35,7 @@ def test_get_local_number_of_trainable_parameters(): model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2)) # Calculate the expected number of trainable parameters - expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + expected_params = 10 * 5 + 5 + 5 * 2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67 # Call the function and check the result assert get_local_number_of_trainable_parameters(model) == expected_params