Skip to content

Commit

Permalink
fix: test of the total-number-of-trainable-parameters computation across FSDP sharding strategies
Browse files Browse the repository at this point in the history
  • Loading branch information
mali-git committed Aug 10, 2024
1 parent f066b7d commit 1c99be5
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from enum import Enum

import torch

import modalities
Expand Down Expand Up @@ -45,11 +47,18 @@ def test_get_total_number_of_trainable_parameters():

# Calculate the expected number of trainable parameters
expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# Simulated distributed layout for the mocks below: 8 ranks total on nodes of
# 4 GPUs each (so world_size // num_gpus_per_node == 2 replicas under HYBRID_SHARD).
world_size = 8
num_gpus_per_node = 4

# Create a mock FSDP model
class MockFSDP:
    """Minimal stand-in for torch's FSDP wrapper used by this test."""

    class ShardingStrategy(Enum):
        # Mirrors the two real FSDP sharding strategies the test exercises.
        FULL_SHARD = "FULL_SHARD"
        HYBRID_SHARD = "HYBRID_SHARD"

    def __init__(self, model):
        # Wrap the underlying module and default to fully sharded,
        # matching FSDP's default strategy.
        self.model = model
        self.sharding_strategy = MockFSDP.ShardingStrategy.FULL_SHARD

# Wrap the toy module the way production code receives an FSDP-wrapped model.
fsdp_model = MockFSDP(model)

Expand All @@ -61,11 +70,29 @@ def mock_all_reduce(tensor, op):
def mock_cuda(tensor):
    """Identity stand-in for ``Tensor.cuda``: tests run on CPU, so hand the tensor back untouched."""
    return tensor

def mock_world_size():
    """Stub for ``torch.distributed.get_world_size``: report the test's fixed rank count."""
    return world_size

def mock_device_count():
    """Stub for ``torch.cuda.device_count``: report the test's fixed GPUs-per-node."""
    return num_gpus_per_node

def mock_get_local_number_of_trainable_parameters(model: MockFSDP) -> int:
    """Stub for the per-rank parameter count, dispatching on the mock's sharding strategy.

    Defect fixed: the original text had a stale unconditional
    ``return get_local_number_of_trainable_parameters(model.model)`` on the first
    body line (a leftover of the replaced implementation), which made the entire
    strategy dispatch below unreachable dead code.
    """
    if model.sharding_strategy == MockFSDP.ShardingStrategy.FULL_SHARD:
        # Fully sharded: each rank reports the wrapped module's local count.
        return get_local_number_of_trainable_parameters(model.model)
    elif model.sharding_strategy == MockFSDP.ShardingStrategy.HYBRID_SHARD:
        # Hybrid sharding replicates across nodes; scale the local count by the
        # number of inter-node replicas (world_size // num_gpus_per_node).
        sharding_factor = world_size // num_gpus_per_node
        return sharding_factor * get_local_number_of_trainable_parameters(model.model)
    else:
        raise ValueError(f"Sharding strategy {model.sharding_strategy} not supported.")

# Monkeypatch the distributed/CUDA entry points so the total-parameter count
# can be computed without initializing a real process group or any GPUs.
modalities.util.get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
torch.distributed.all_reduce = mock_all_reduce
torch.distributed.get_world_size = mock_world_size
torch.cuda.device_count = mock_device_count
torch.Tensor.cuda = mock_cuda

# FULL_SHARD: mocked local count equals the full model's count.
assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params

# HYBRID_SHARD: same expected total — the mock scales the local count by the
# inter-node sharding factor, presumably compensating for a division inside
# get_total_number_of_trainable_parameters (NOTE(review): verify against that
# function's implementation). Re-assigning the patch here looks redundant but
# is harmless.
fsdp_model.sharding_strategy = MockFSDP.ShardingStrategy.HYBRID_SHARD
modalities.util.get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params

0 comments on commit 1c99be5

Please sign in to comment.