diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index f005f1ab58b6..20df5eaf81bd 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1074,11 +1074,25 @@ def _configure_distributed_model(self, model): if self.fp16_enabled(): if is_zero_init_model: self.__check_params(self.module, torch.half) - self.module.half() + # selectively avoid casting specially + # marked parameters to 16-bit + self.module._apply( + lambda t: t.half() if ( + t.is_floating_point() and + not getattr(t, "_deepspeed_no_cast", False) + ) else t + ) elif self.bfloat16_enabled(): if is_zero_init_model: self.__check_params(self.module, torch.bfloat16) - self.module.bfloat16() + # selectively avoid casting specially + # marked parameters to 16-bit + self.module._apply( + lambda t: t.bfloat16() if ( + t.is_floating_point() and + not getattr(t, "_deepspeed_no_cast", False) + ) else t + ) else: self.__check_params(self.module, torch.float)