Skip to content

Commit

Permalink
Merge pull request #229 from Modalities/fix/test_utils
Browse files Browse the repository at this point in the history
Fix linting & improve test_utils
  • Loading branch information
le1nux authored Aug 13, 2024
2 parents bba4ecb + b13aeeb commit db18bfb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_get_local_number_of_trainable_parameters():
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))

# Calculate the expected number of trainable parameters
expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
expected_params = 10 * 5 + 5 + 5 * 2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67

# Call the function and check the result
assert get_local_number_of_trainable_parameters(model) == expected_params
Expand All @@ -46,7 +46,7 @@ def test_get_total_number_of_trainable_parameters():
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))

# Calculate the expected number of trainable parameters
expected_params = 10*5 + 5 + 5*2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67
expected_params = 10 * 5 + 5 + 5 * 2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67
world_size = 8
num_gpus_per_node = 4

Expand Down

0 comments on commit db18bfb

Please sign in to comment.