Support compute dtype in spmd fsdp (#22)
lausannel authored Sep 20, 2024
1 parent de882ac · commit 27eccca
Showing 1 changed file with 7 additions and 0 deletions.
torchacc/dist/spmd_fsdp.py (7 additions, 0 deletions)

@@ -49,6 +49,12 @@ def fsdp(self, model: torch.nn.Module, config: Config):
             transformer_layer_cls=layer_cls,
         )
 
+        dtype = torch.float32
+        if config.compute.fp16:
+            dtype = torch.float16
+        if config.compute.bf16:
+            dtype = torch.bfloat16
+
         auto_wrapper_callable = None
         if config.memory.gc and (config.memory.gc_cls
                                  == config.dist.fsdp.wrap_layer_cls):
@@ -69,6 +75,7 @@ def auto_wrapper_callable(m, *args, **kwargs):
             model,
             mesh,
             shard_output=self.shard_output_callable,
+            compute_dtype=dtype,
             auto_wrap_policy=auto_wrap_policy,
             auto_wrapper_callable=auto_wrapper_callable)
         return model
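The substance of the change is the dtype-selection logic in the first hunk. Below is a minimal, self-contained sketch of that logic and its flag precedence, assuming only the two boolean fields (compute.fp16, compute.bf16) visible in the diff; the ComputeConfig stand-in and the select_compute_dtype helper are hypothetical names used for illustration and are not part of torchacc.

from dataclasses import dataclass

import torch


@dataclass
class ComputeConfig:  # hypothetical stand-in for torchacc's config.compute
    fp16: bool = False
    bf16: bool = False


def select_compute_dtype(compute: ComputeConfig) -> torch.dtype:
    # Mirrors the commit: default to full precision, then let the
    # half-precision flags override it. bf16 is checked last, so it
    # takes precedence when both flags are set.
    dtype = torch.float32
    if compute.fp16:
        dtype = torch.float16
    if compute.bf16:
        dtype = torch.bfloat16
    return dtype


assert select_compute_dtype(ComputeConfig()) is torch.float32
assert select_compute_dtype(ComputeConfig(fp16=True)) is torch.float16
assert select_compute_dtype(ComputeConfig(fp16=True, bf16=True)) is torch.bfloat16

The selected dtype is then passed to the SPMD FSDP wrapper via compute_dtype=dtype, as the second hunk shows.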
