diff --git a/tests/test_utils.py b/tests/test_utils.py index caa0fd01..2850f257 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -35,7 +35,7 @@ def test_get_local_number_of_trainable_parameters(): model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2)) # Calculate the expected number of trainable parameters - expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + expected_params = 10 * 5 + 5 + 5 * 2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67 # Call the function and check the result assert get_local_number_of_trainable_parameters(model) == expected_params @@ -46,7 +46,7 @@ def test_get_total_number_of_trainable_parameters(): model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2)) # Calculate the expected number of trainable parameters - expected_params = 10*5 + 5 + 5*2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67 + expected_params = 10 * 5 + 5 + 5 * 2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67 world_size = 8 num_gpus_per_node = 4