class AdEMAMix32bit(Optimizer2State):
    """32-bit AdEMAMix optimizer.

    AdEMAMix (Pagliardini et al., 2024) tracks two gradient EMAs — a fast
    one (decay ``betas[0]``) and a slow one (decay ``betas[2]``) — and mixes
    them with weight ``alpha`` in the update.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
        alpha: float = 5.0,
        t_alpha: Optional[int] = None,
        t_beta3: Optional[int] = None,
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        min_8bit_size: int = 4096,
        is_paged: bool = False,
    ):
        """
        Args:
            params: Iterable of parameters to optimize.
            lr: Learning rate. Defaults to 1e-3.
            betas: Decay rates ``(beta1, beta2, beta3)`` for the fast EMA,
                the second moment, and the slow EMA. Defaults to
                ``(0.9, 0.999, 0.9999)``.
            alpha: Mixing coefficient for the slow EMA. Defaults to 5.0.
            t_alpha: Number of warmup steps over which ``alpha`` is
                scheduled; ``None`` disables the schedule.
            t_beta3: Number of warmup steps over which ``beta3`` is
                scheduled; ``None`` disables the schedule.
            eps: Epsilon added to the denominator. Defaults to 1e-8.
            weight_decay: Decoupled weight decay. Defaults to 1e-2.
            min_8bit_size: Minimum parameter tensor size for 8-bit state
                (kept for signature parity with the 8-bit variants; this
                optimizer always uses 32-bit state).
            is_paged: Whether to use paged (CPU-backed) optimizer state.
        """
        super().__init__(
            "ademamix",
            params=params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            optim_bits=32,
            args=None,
            min_8bit_size=min_8bit_size,
            percentile_clipping=100,
            block_wise=True,
            is_paged=is_paged,
            alpha=alpha,
            # FIX: t_alpha/t_beta3 were accepted (and forwarded by
            # PagedAdEMAMix32bit) but silently dropped here, so the warmup
            # schedules never reached the optimizer core. Forward them as the
            # sibling AdEMAMix variants in this module do.
            # NOTE(review): assumes Optimizer2State accepts t_alpha/t_beta3
            # keywords, as the alpha= keyword usage here implies — confirm
            # against its signature.
            t_alpha=t_alpha,
            t_beta3=t_beta3,
        )


class PagedAdEMAMix32bit(AdEMAMix32bit):
    """Paged 32-bit AdEMAMix: identical to :class:`AdEMAMix32bit` with
    ``is_paged=True`` fixed, so optimizer state lives in paged (CPU-backed)
    memory."""

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
        alpha: float = 5.0,
        t_alpha: Optional[int] = None,
        t_beta3: Optional[int] = None,
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        min_8bit_size: int = 4096,
    ):
        """
        Args:
            params: Iterable of parameters to optimize.
            lr: Learning rate. Defaults to 1e-3.
            betas: Decay rates ``(beta1, beta2, beta3)``.
            alpha: Mixing coefficient for the slow EMA. Defaults to 5.0.
            t_alpha: Warmup steps for the ``alpha`` schedule, or ``None``.
            t_beta3: Warmup steps for the ``beta3`` schedule, or ``None``.
            eps: Epsilon added to the denominator. Defaults to 1e-8.
            weight_decay: Decoupled weight decay. Defaults to 1e-2.
            min_8bit_size: Kept for signature parity with 8-bit variants.
        """
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            alpha=alpha,
            t_alpha=t_alpha,
            t_beta3=t_beta3,
            eps=eps,
            weight_decay=weight_decay,
            min_8bit_size=min_8bit_size,
            is_paged=True,
        )