diff --git a/src/modalities/util.py b/src/modalities/util.py index 0cc19fb6..05a2f80f 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -77,6 +77,7 @@ def get_total_number_of_trainable_parameters(model: FSDP) -> Number: dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM) total_num_params = num_params_tensor.item() # For HYBRID sharding, divide by sharding factor to get the correct number of parameters + # TODO: Define constant instead of hardcoding string if model.sharding_strategy.name == "HYBRID_SHARD": # Assumes that CUDA is available and each node has the same number of GPUs # Note: Per default FSDP constructs process groups for the user to shard intra-node and replicate inter-node.