
Commit bba4ecb

Merge pull request #216 from Modalities/fix_num_params_hybrid_sharding
fix: Computation of Number of Parameters in HYBRID Sharding
2 parents (10c6022 + ee2dd2a) · commit bba4ecb

File tree: 2 files changed, +39 -2 lines changed

src/modalities/util.py

Lines changed: 10 additions & 0 deletions

@@ -76,6 +76,16 @@ def get_total_number_of_trainable_parameters(model: FSDP) -> Number:
     num_params_tensor = torch.tensor(num_params).cuda()
     dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM)
     total_num_params = num_params_tensor.item()
+    # For HYBRID sharding, divide by the sharding factor to get the correct number of parameters.
+    # TODO: Define a constant instead of hardcoding the string.
+    if model.sharding_strategy.name == "HYBRID_SHARD":
+        # Assumes that CUDA is available and that each node has the same number of GPUs.
+        # Note: By default, FSDP constructs process groups that shard intra-node and replicate
+        # inter-node. However, users can also provide their own sharding process groups
+        # (currently not supported in Modalities), which would require adapting this code.
+        sharding_factor_hybrid_sharding = dist.get_world_size() // torch.cuda.device_count()
+        total_num_params = total_num_params // sharding_factor_hybrid_sharding
+
     return total_num_params
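Why the division is needed: under HYBRID_SHARD, FSDP shards parameters across the GPUs within a node and replicates the model across nodes, so the all_reduce sum over all ranks counts the model once per node. A minimal plain-Python sketch of the arithmetic (the cluster shape and parameter count below are illustrative, not taken from the commit):

# Hypothetical cluster shape: 2 nodes x 4 GPUs, under HYBRID_SHARD.
world_size = 8          # total number of ranks
gpus_per_node = 4       # intra-node sharding group size
total_params = 1_000    # true number of trainable parameters

# Each rank holds a 1/gpus_per_node shard, so each node holds one full replica.
local_params = total_params // gpus_per_node    # 250 parameters per rank

# all_reduce(SUM) over all ranks therefore counts one full copy per node:
summed = local_params * world_size              # 2_000, i.e. the model counted twice

# Dividing by the number of inter-node replicas recovers the true count:
sharding_factor = world_size // gpus_per_node   # 2 replicas
assert summed // sharding_factor == total_params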

tests/test_utils.py

Lines changed: 29 additions & 2 deletions

@@ -1,3 +1,5 @@
+from enum import Enum
+
 import torch
 
 import modalities
@@ -44,12 +46,19 @@ def test_get_total_number_of_trainable_parameters():
     model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
 
     # Calculate the expected number of trainable parameters
-    expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    expected_params = 10 * 5 + 5 + 5 * 2 + 2  # weights_1 + bias_1 + weights_2 + bias_2 = 67
+    world_size = 8
+    num_gpus_per_node = 4
 
     # Create a mock FSDP model
     class MockFSDP:
+        class ShardingStrategy(Enum):
+            FULL_SHARD = "FULL_SHARD"
+            HYBRID_SHARD = "HYBRID_SHARD"
+
         def __init__(self, model):
             self.model = model
+            self.sharding_strategy = self.ShardingStrategy.FULL_SHARD
 
     fsdp_model = MockFSDP(model)
 
@@ -61,11 +70,29 @@ def mock_all_reduce(tensor, op):
     def mock_cuda(tensor):
         return tensor
 
+    def mock_world_size():
+        return world_size
+
+    def mock_device_count():
+        return num_gpus_per_node
+
     def mock_get_local_number_of_trainable_parameters(model: MockFSDP):
-        return get_local_number_of_trainable_parameters(model.model)
+        if model.sharding_strategy == MockFSDP.ShardingStrategy.FULL_SHARD:
+            return get_local_number_of_trainable_parameters(model.model)
+        elif model.sharding_strategy == MockFSDP.ShardingStrategy.HYBRID_SHARD:
+            sharding_factor = world_size // num_gpus_per_node
+            return sharding_factor * get_local_number_of_trainable_parameters(model.model)
+        else:
+            raise ValueError(f"Sharding strategy {model.sharding_strategy} not supported.")
 
     modalities.util.get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
     torch.distributed.all_reduce = mock_all_reduce
+    torch.distributed.get_world_size = mock_world_size
+    torch.cuda.device_count = mock_device_count
     torch.Tensor.cuda = mock_cuda
 
     assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params
+
+    fsdp_model.sharding_strategy = MockFSDP.ShardingStrategy.HYBRID_SHARD
+    modalities.util.get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
+    assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params
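Hand-tracing the HYBRID_SHARD assertion shows why it passes in a single-process test run: the patched helper reports sharding_factor * 67 = 134 local parameters, mock_all_reduce leaves that sum unchanged, and get_total_number_of_trainable_parameters then divides by mock_world_size() // mock_device_count() = 2. A sketch of that arithmetic (values copied from the test above):

# Values from the test: a 67-parameter model on a mocked 8-rank, 4-GPU-per-node setup.
expected_params = 67                             # 10 * 5 + 5 + 5 * 2 + 2
sharding_factor = 8 // 4                         # mock_world_size() // mock_device_count()
local_count = sharding_factor * expected_params  # 134, from the patched local-count helper
# mock_all_reduce is the identity, so the "global" sum is still 134;
# the function under test then divides it back down:
assert local_count // sharding_factor == expected_params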
