From 27ecccac885d371ff2b548e9bea4e5e88cf0c778 Mon Sep 17 00:00:00 2001
From: Zhan Lu <51200935+lausannel@users.noreply.github.com>
Date: Fri, 20 Sep 2024 10:52:53 +0800
Subject: [PATCH] Support compute dtype in spmd fsdp (#22)

---
 torchacc/dist/spmd_fsdp.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/torchacc/dist/spmd_fsdp.py b/torchacc/dist/spmd_fsdp.py
index bd5fd22..a815c60 100644
--- a/torchacc/dist/spmd_fsdp.py
+++ b/torchacc/dist/spmd_fsdp.py
@@ -49,6 +49,12 @@ def fsdp(self, model: torch.nn.Module, config: Config):
             transformer_layer_cls=layer_cls,
         )
 
+        dtype = torch.float32
+        if config.compute.fp16:
+            dtype = torch.float16
+        if config.compute.bf16:
+            dtype = torch.bfloat16
+
         auto_wrapper_callable = None
         if config.memory.gc and (config.memory.gc_cls ==
                                  config.dist.fsdp.wrap_layer_cls):
@@ -69,6 +75,7 @@ def auto_wrapper_callable(m, *args, **kwargs):
             model,
             mesh,
             shard_output=self.shard_output_callable,
+            compute_dtype=dtype,
             auto_wrap_policy=auto_wrap_policy,
             auto_wrapper_callable=auto_wrapper_callable)
         return model
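
Note (for review context, not part of the patch): the dtype selection added above can be exercised in isolation. The sketch below is a minimal, self-contained Python illustration; ComputeConfig and resolve_compute_dtype are hypothetical stand-ins for torchacc's config.compute section and the logic inlined in fsdp(). Because bf16 is tested after fp16, bf16 takes precedence when both flags are set; the resolved dtype is what the patch forwards as compute_dtype to the SPMD FSDP wrapper call.

    from dataclasses import dataclass

    import torch

    @dataclass
    class ComputeConfig:
        # Hypothetical stand-in for torchacc's config.compute flags.
        fp16: bool = False
        bf16: bool = False

    def resolve_compute_dtype(compute: ComputeConfig) -> torch.dtype:
        # Mirrors the patch: float32 by default; fp16/bf16 override it,
        # with bf16 winning because it is checked last.
        dtype = torch.float32
        if compute.fp16:
            dtype = torch.float16
        if compute.bf16:
            dtype = torch.bfloat16
        return dtype

    assert resolve_compute_dtype(ComputeConfig()) == torch.float32
    assert resolve_compute_dtype(ComputeConfig(fp16=True)) == torch.float16
    assert resolve_compute_dtype(ComputeConfig(fp16=True, bf16=True)) == torch.bfloat16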