diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index f1e60e5e7..39fa0e7ff 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -437,7 +437,7 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32: state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: @@ -656,7 +656,7 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32: state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: