Add PagedAdEMAMix32bit, AdEMAMix32bit
matthewdouglas committed Sep 16, 2024
1 parent d8c4b39 commit 0911854
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions bitsandbytes/optim/ademamix.py
@@ -252,3 +252,61 @@ def __init__(
            min_8bit_size=min_8bit_size,
            is_paged=True,
        )


class AdEMAMix32bit(Optimizer2State):
    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
        alpha: float = 5.0,
        t_alpha: Optional[int] = None,
        t_beta3: Optional[int] = None,
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        min_8bit_size: int = 4096,
        is_paged: bool = False,
    ):
        super().__init__(
            "ademamix",
            params=params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            optim_bits=32,
            args=None,
            min_8bit_size=min_8bit_size,
            percentile_clipping=100,
            block_wise=True,
            is_paged=is_paged,
            alpha=alpha,
        )


class PagedAdEMAMix32bit(AdEMAMix32bit):
    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
        alpha: float = 5.0,
        t_alpha: Optional[int] = None,
        t_beta3: Optional[int] = None,
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        min_8bit_size: int = 4096,
    ):
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            alpha=alpha,
            t_alpha=t_alpha,
            t_beta3=t_beta3,
            eps=eps,
            weight_decay=weight_decay,
            min_8bit_size=min_8bit_size,
            is_paged=True,
        )
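
Usage sketch (not part of the commit): the new classes follow the usual bitsandbytes optimizer pattern, so they can stand in for a standard PyTorch optimizer in a training loop. The model, data, and import path below are illustrative assumptions; in particular, it assumes both classes are exported from bitsandbytes.optim.

# Hypothetical usage of the new 32-bit AdEMAMix optimizers (assumes they are
# exported from bitsandbytes.optim and that a CUDA device is available).
import torch
from bitsandbytes.optim import AdEMAMix32bit, PagedAdEMAMix32bit

model = torch.nn.Linear(1024, 1024).cuda()

# Full-precision (32-bit) optimizer state; use PagedAdEMAMix32bit (or
# is_paged=True) to allow optimizer state to be paged out of GPU memory.
optimizer = AdEMAMix32bit(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999, 0.9999),  # (beta1, beta2, beta3)
    alpha=5.0,
    weight_decay=1e-2,
)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()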
