diff --git a/bitsandbytes/optim/ademamix.py b/bitsandbytes/optim/ademamix.py
index 5bde9e144..fca9ac7b5 100644
--- a/bitsandbytes/optim/ademamix.py
+++ b/bitsandbytes/optim/ademamix.py
@@ -252,3 +252,65 @@ def __init__(
             min_8bit_size=min_8bit_size,
             is_paged=True,
         )
+
+
+class AdEMAMix32bit(Optimizer2State):
+    """32-bit AdEMAMix optimizer: Adam-style update blending two gradient EMAs (betas[0], betas[2])."""
+
+    def __init__(
+        self,
+        params: Iterable[torch.nn.Parameter],
+        lr: float = 1e-3,
+        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        alpha: float = 5.0,
+        t_alpha: Optional[int] = None,
+        t_beta3: Optional[int] = None,
+        eps: float = 1e-8,
+        weight_decay: float = 1e-2,
+        min_8bit_size: int = 4096,
+        is_paged: bool = False,
+    ):
+        super().__init__(
+            "ademamix",
+            params=params,
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            optim_bits=32,
+            args=None,
+            min_8bit_size=min_8bit_size,
+            percentile_clipping=100,
+            block_wise=True,
+            is_paged=is_paged,
+            alpha=alpha,
+            # Forward the warmup schedules; dropping them would silently
+            # ignore user-supplied t_alpha / t_beta3.
+            t_alpha=t_alpha,
+            t_beta3=t_beta3,
+        )
+
+
+class PagedAdEMAMix32bit(AdEMAMix32bit):
+    """Variant of `AdEMAMix32bit` constructed with `is_paged=True`."""
+
+    def __init__(
+        self,
+        params: Iterable[torch.nn.Parameter],
+        lr: float = 1e-3,
+        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        alpha: float = 5.0,
+        t_alpha: Optional[int] = None,
+        t_beta3: Optional[int] = None,
+        eps: float = 1e-8,
+        weight_decay: float = 1e-2,
+        min_8bit_size: int = 4096,
+    ):
+        super().__init__(
+            params,
+            lr=lr,
+            betas=betas,
+            alpha=alpha,
+            t_alpha=t_alpha,
+            t_beta3=t_beta3,
+            eps=eps,
+            weight_decay=weight_decay,
+            min_8bit_size=min_8bit_size,
+            is_paged=True,
+        )