1
+ from enum import Enum
2
+
1
3
import torch
2
4
3
5
import modalities
@@ -44,12 +46,19 @@ def test_get_total_number_of_trainable_parameters():
44
46
model = torch .nn .Sequential (torch .nn .Linear (10 , 5 ), torch .nn .ReLU (), torch .nn .Linear (5 , 2 ))
45
47
46
48
# Calculate the expected number of trainable parameters
47
- expected_params = sum (p .numel () for p in model .parameters () if p .requires_grad )
49
+ expected_params = 10 * 5 + 5 + 5 * 2 + 2 # weights_1 + bias_1 + weights_2 + bias_2 = 67
50
+ world_size = 8
51
+ num_gpus_per_node = 4
48
52
49
53
# Create a mock FSDP model
50
54
class MockFSDP:
    """Minimal stand-in for an FSDP-wrapped model used by this test.

    It only carries the wrapped module and a ``sharding_strategy`` attribute,
    which the mocked parameter-counting helper inspects.
    """

    class ShardingStrategy(Enum):
        """Sharding strategies the mock can be switched between."""

        FULL_SHARD = "FULL_SHARD"
        HYBRID_SHARD = "HYBRID_SHARD"

    def __init__(self, model):
        """Store ``model`` and start out with the FULL_SHARD strategy."""
        self.model = model
        # The test flips this attribute to HYBRID_SHARD later on.
        self.sharding_strategy = self.ShardingStrategy.FULL_SHARD
53
62
54
63
fsdp_model = MockFSDP (model )
55
64
@@ -61,11 +70,29 @@ def mock_all_reduce(tensor, op):
61
70
def mock_cuda(tensor):
    """Stand-in for ``torch.Tensor.cuda`` that returns ``tensor`` unchanged."""
    return tensor
63
72
73
def mock_world_size():
    """Return the simulated world size defined by the enclosing test.

    Installed in place of ``torch.distributed.get_world_size`` by this test.
    """
    return world_size
75
+
76
def mock_device_count():
    """Return the simulated number of GPUs per node defined by the enclosing test.

    Installed in place of ``torch.cuda.device_count`` by this test.
    """
    return num_gpus_per_node
78
+
64
79
def mock_get_local_number_of_trainable_parameters(model: MockFSDP):
    """Count local trainable parameters of the module wrapped by ``model``.

    For FULL_SHARD the count of ``model.model`` is returned as-is. For
    HYBRID_SHARD the count is multiplied by ``world_size // num_gpus_per_node``.

    NOTE(review): the scaling presumably offsets a correction applied by
    ``get_total_number_of_trainable_parameters`` for hybrid sharding -- confirm
    against the implementation under test.

    Raises:
        ValueError: if ``model.sharding_strategy`` is neither FULL_SHARD nor
            HYBRID_SHARD.
    """
    if model.sharding_strategy == MockFSDP.ShardingStrategy.FULL_SHARD:
        return get_local_number_of_trainable_parameters(model.model)
    elif model.sharding_strategy == MockFSDP.ShardingStrategy.HYBRID_SHARD:
        # Both values are the simulated constants set up by the enclosing test.
        sharding_factor = world_size // num_gpus_per_node
        return sharding_factor * get_local_number_of_trainable_parameters(model.model)
    else:
        raise ValueError(f"Sharding strategy {model.sharding_strategy} not supported.")
66
87
67
88
modalities .util .get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
68
89
torch .distributed .all_reduce = mock_all_reduce
90
+ torch .distributed .get_world_size = mock_world_size
91
+ torch .cuda .device_count = mock_device_count
69
92
torch .Tensor .cuda = mock_cuda
70
93
71
94
assert get_total_number_of_trainable_parameters (fsdp_model ) == expected_params
95
+
96
+ fsdp_model .sharding_strategy = MockFSDP .ShardingStrategy .HYBRID_SHARD
97
+ modalities .util .get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
98
+ assert get_total_number_of_trainable_parameters (fsdp_model ) == expected_params
0 commit comments