Merge pull request #216 from Modalities/fix_num_params_hybrid_sharding
fix: Computation of Number of Parameters in HYBRID Sharding
mali-git authored Aug 13, 2024
2 parents 10c6022 + ee2dd2a commit bba4ecb
Showing 2 changed files with 39 additions and 2 deletions.
10 changes: 10 additions & 0 deletions src/modalities/util.py
@@ -76,6 +76,16 @@ def get_total_number_of_trainable_parameters(model: FSDP) -> Number:
    num_params_tensor = torch.tensor(num_params).cuda()
    dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM)
    total_num_params = num_params_tensor.item()
    # For HYBRID sharding, divide by the sharding factor to get the correct number of parameters
    # TODO: Define a constant instead of hardcoding the string
    if model.sharding_strategy.name == "HYBRID_SHARD":
        # Assumes that CUDA is available and that each node has the same number of GPUs.
        # Note: By default, FSDP constructs process groups such that parameters are sharded intra-node
        # and replicated inter-node. However, users can also provide their own sharding process groups
        # (currently not supported in Modalities), which would require adapting this code.
        sharding_factor_hybrid_sharding = dist.get_world_size() // torch.cuda.device_count()
        total_num_params = total_num_params // sharding_factor_hybrid_sharding

    return total_num_params


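For context, HYBRID_SHARD shards a model across the GPUs of a single node and replicates it across nodes, so summing the per-rank parameter counts with all_reduce counts the model once per node replica. A minimal arithmetic sketch of the correction, assuming a hypothetical setup of 2 nodes with 4 GPUs each and a 1,000,000-parameter model:

# Hypothetical setup (not from the repository): 2 nodes x 4 GPUs, 1,000,000 trainable parameters.
world_size = 8
gpus_per_node = 4
total_params = 1_000_000

# Under HYBRID_SHARD each node holds a full replica sharded over its local GPUs,
# so the all-reduce over all 8 ranks counts the model once per node replica.
num_replicas = world_size // gpus_per_node      # sharding factor = 2
all_reduced_sum = num_replicas * total_params   # 2,000,000 before the fix

# The fix divides the all-reduced sum by the replication factor.
assert all_reduced_sum // num_replicas == total_params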
31 changes: 29 additions & 2 deletions tests/test_utils.py
@@ -1,3 +1,5 @@
from enum import Enum

import torch

import modalities
@@ -44,12 +46,19 @@ def test_get_total_number_of_trainable_parameters():
    model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))

    # Calculate the expected number of trainable parameters
    expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    expected_params = 10 * 5 + 5 + 5 * 2 + 2  # weights_1 + bias_1 + weights_2 + bias_2 = 67
    world_size = 8
    num_gpus_per_node = 4

    # Create a mock FSDP model
    class MockFSDP:
        class ShardingStrategy(Enum):
            FULL_SHARD = "FULL_SHARD"
            HYBRID_SHARD = "HYBRID_SHARD"

        def __init__(self, model):
            self.model = model
            self.sharding_strategy = self.ShardingStrategy.FULL_SHARD

    fsdp_model = MockFSDP(model)

@@ -61,11 +70,29 @@ def mock_all_reduce(tensor, op):
    def mock_cuda(tensor):
        return tensor

    def mock_world_size():
        return world_size

    def mock_device_count():
        return num_gpus_per_node

    def mock_get_local_number_of_trainable_parameters(model: MockFSDP):
        return get_local_number_of_trainable_parameters(model.model)
        if model.sharding_strategy == MockFSDP.ShardingStrategy.FULL_SHARD:
            return get_local_number_of_trainable_parameters(model.model)
        elif model.sharding_strategy == MockFSDP.ShardingStrategy.HYBRID_SHARD:
            sharding_factor = world_size // num_gpus_per_node
            return sharding_factor * get_local_number_of_trainable_parameters(model.model)
        else:
            raise ValueError(f"Sharding strategy {model.sharding_strategy} not supported.")

    modalities.util.get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
    torch.distributed.all_reduce = mock_all_reduce
    torch.distributed.get_world_size = mock_world_size
    torch.cuda.device_count = mock_device_count
    torch.Tensor.cuda = mock_cuda

    assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params

    fsdp_model.sharding_strategy = MockFSDP.ShardingStrategy.HYBRID_SHARD
    modalities.util.get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
    assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params
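A short sketch of the numbers the HYBRID_SHARD branch of this test exercises, assuming the mocked all_reduce returns the tensor unchanged (its body is not shown in the hunk above):

# Hypothetical walkthrough of the mocked HYBRID_SHARD path.
expected_params = 10 * 5 + 5 + 5 * 2 + 2                  # 67
sharding_factor = 8 // 4                                   # world_size // num_gpus_per_node = 2
mocked_local_count = sharding_factor * expected_params     # 134, as returned by the mock
# The fixed get_total_number_of_trainable_parameters divides by the sharding factor,
# recovering the expected 67 parameters.
assert mocked_local_count // sharding_factor == expected_params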
