From 6d097beccc4e3b0ac806c7d975f8c10d4689de26 Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:24:20 -0500 Subject: [PATCH] add _deepspeed_no_cast attribute (#61) --- deepspeed/runtime/engine.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index f005f1ab58b6..20df5eaf81bd 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1074,11 +1074,25 @@ def _configure_distributed_model(self, model): if self.fp16_enabled(): if is_zero_init_model: self.__check_params(self.module, torch.half) - self.module.half() + # selectively avoid casting specially + # marked parameters to 16-bit + self.module._apply( + lambda t: t.half() if ( + t.is_floating_point() and + not getattr(t, "_deepspeed_no_cast", False) + ) else t + ) elif self.bfloat16_enabled(): if is_zero_init_model: self.__check_params(self.module, torch.bfloat16) - self.module.bfloat16() + # selectively avoid casting specially + # marked parameters to 16-bit + self.module._apply( + lambda t: t.bfloat16() if ( + t.is_floating_point() and + not getattr(t, "_deepspeed_no_cast", False) + ) else t + ) else: self.__check_params(self.module, torch.float)