
Commit

Edenzzzz's fix for min_8bit_size functionality in Optimizer base classes (#1286)

* Fix bug where min_8bit_size had no effect

* Apply the same fix to the other optimizer base class

---------

Co-authored-by: Edenzzzz <[email protected]>
Titus-von-Koeller and Edenzzzz committed Jul 22, 2024
1 parent 0bdd57c commit 5212a0f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bitsandbytes/optim/optimizer.py
@@ -437,7 +437,7 @@ def init_state(self, group, p, gindex, pindex):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32:
             state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
             state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
@@ -656,7 +656,7 @@ def init_state(self, group, p, gindex, pindex):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32:
             state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
             if state["step"] == 0:
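
Context for the change: the removed hardcoded p.numel() < 4096 condition duplicated the min_8bit_size threshold with a fixed value, so parameters with fewer than 4096 elements always received 32-bit state buffers regardless of the min_8bit_size passed to the optimizer. The sketch below is not part of the commit; it assumes a CUDA-enabled bitsandbytes install and illustrates how min_8bit_size is intended to behave once it alone controls the 8-bit vs. 32-bit state decision.

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Sequential(
        torch.nn.Linear(32, 64),   # weight has 32 * 64 = 2048 elements (plus a 64-element bias)
        torch.nn.Linear(64, 64),   # weight has 64 * 64 = 4096 elements
    )

    # With min_8bit_size=1024, the 2048-element weight should get 8-bit optimizer
    # state. Before this fix, the hardcoded 4096 cutoff forced it back to 32-bit
    # state, making values of min_8bit_size below 4096 ineffective.
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, min_8bit_size=1024)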
