From 5212a0f2a585abba1cc2a65f82f7c4fc939c453f Mon Sep 17 00:00:00 2001 From: Titus <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:36:31 +0200 Subject: [PATCH] Edenzzzz's fix for min_8bit_size functionality in Optimizer base classes (#1286) * fix min_8bit_size invalid bug * Apply same fix to other optimizer base class --------- Co-authored-by: Edenzzzz --- bitsandbytes/optim/optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index f1e60e5e7..39fa0e7ff 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -437,7 +437,7 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32: state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: @@ -656,7 +656,7 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32: state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: