Commit

chore: add comment
mali-git committed Aug 10, 2024
1 parent 1c99be5 commit 7bf0a03
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions src/modalities/util.py
@@ -77,6 +77,7 @@ def get_total_number_of_trainable_parameters(model: FSDP) -> Number:
dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM)
total_num_params = num_params_tensor.item()
# For HYBRID sharding, divide by sharding factor to get the correct number of parameters
# TODO: Define constant instead of hardcoding string
if model.sharding_strategy.name == "HYBRID_SHARD":
# Assumes that CUDA is available and each node has the same number of GPUs
# Note: Per default FSDP constructs process groups for the user to shard intra-node and replicate inter-node.
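
The TODO in this hunk asks for a constant in place of the hardcoded "HYBRID_SHARD" string. A minimal sketch of one way to do that is given here, under stated assumptions: it compares against an enum member rather than a string, and divides the all-reduced count by the number of model replicas. `ShardingStrategy` below is a local stand-in so the sketch runs without PyTorch (the real enum is `torch.distributed.fsdp.ShardingStrategy`), and `correct_param_count`, `world_size`, and `gpus_per_node` are hypothetical names that do not appear in the commit. The replica arithmetic is inferred from the comments in the hunk (shard intra-node, replicate inter-node), not from the elided code.

```python
from enum import Enum, auto


class ShardingStrategy(Enum):
    """Local stand-in for torch.distributed.fsdp.ShardingStrategy (assumption)."""

    FULL_SHARD = auto()
    HYBRID_SHARD = auto()


def correct_param_count(
    summed_params: int,
    strategy: ShardingStrategy,
    world_size: int,
    gpus_per_node: int,
) -> int:
    """Return the parameter count after summing local counts over all ranks.

    Assumes every node has the same number of GPUs, as the original comment does.
    """
    # Compare against the enum member instead of the string "HYBRID_SHARD".
    if strategy is ShardingStrategy.HYBRID_SHARD:
        # Sharding within a node and replicating across nodes means each node
        # holds one full copy, so the sum counts the model once per node.
        num_replicas = world_size // gpus_per_node
        return summed_params // num_replicas
    return summed_params
```

For example, with 8 ranks spread over 2 nodes of 4 GPUs each, a summed count of 2000 under hybrid sharding corresponds to 1000 distinct parameters, while the same sum under full sharding is left unchanged.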