[Prototype] Add param-to-lr interface to distributed Shampoo #22

Open · wants to merge 2 commits into base: main
188 changes: 166 additions & 22 deletions distributed_shampoo/distributed_shampoo.py
@@ -296,6 +296,10 @@ class DistributedShampoo(torch.optim.Optimizer):
3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.
track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes.
(Default: False)
experimental_param_to_lr (Optional[Callable[[Tensor], float]]): Optional mapping from parameter to learning rate.
If set, this mapping must cover all parameters in param_groups.
This setting supersedes the learning rate of each parameter group.
(Default: None)

"""

@@ -326,6 +330,7 @@ def __init__(
precision_config: Optional[PrecisionConfig] = None,
use_protected_eigh: bool = True,
track_root_inv_residuals: bool = False,
experimental_param_to_lr: Optional[Callable[[torch.Tensor], float]] = None,
Contributor:
We do not want to expose this directly to the user; instead, we should add a flag that merges parameter groups that have different lr, betas, beta3, epsilon, momentum, dampening, or weight_decay but share the same fields everywhere else.
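A minimal sketch of the direction suggested here; the function name, the set of per-parameter keys, and the assumption that each group carries an explicit lr are illustrative, not part of this PR:

```python
from typing import Any, Dict, List, Tuple

import torch

# Hyperparameters allowed to differ between groups that get merged.
PER_PARAM_KEYS = {"lr", "betas", "beta3", "epsilon", "momentum", "dampening", "weight_decay"}


def merge_groups_and_collect_lrs(
    param_groups: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], Dict[torch.Tensor, float]]:
    merged: Dict[Tuple[Tuple[str, str], ...], Dict[str, Any]] = {}
    param_to_lr: Dict[torch.Tensor, float] = {}
    for group in param_groups:
        # Groups must agree on every field outside PER_PARAM_KEYS to be merged.
        key = tuple(
            sorted((k, repr(v)) for k, v in group.items() if k != "params" and k not in PER_PARAM_KEYS)
        )
        # The merged group keeps the first group's per-parameter hyperparameters;
        # each parameter's own lr is tracked separately in param_to_lr.
        bucket = merged.setdefault(key, {**group, "params": []})
        bucket["params"].extend(group["params"])
        for p in group["params"]:
            param_to_lr[p] = group["lr"]
    return list(merged.values()), param_to_lr
```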

) -> None:
# Hyperparameter checks.
if not lr >= 0.0:
@@ -474,6 +479,7 @@ def __init__(
self._shampoo_pt2_compile_config: Optional[ShampooPT2CompileConfig] = (
shampoo_pt2_compile_config
)
self._experimental_param_to_lr = experimental_param_to_lr
Contributor:
To support this properly, we will need a function that automatically constructs this mapping from each parameter (within each parameter group) to its learning rate, and we will need to modify the parameter groups typically defined by torch.optim.Optimizer (see https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer).

This may need to be moved before the super().__init__() call above, so we can run it on the parameter groups.
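A minimal sketch of such a constructor-side helper, built directly from the user-supplied params argument; the function name and structure are assumptions. If params is a generator it would need to be materialized first so it can still be handed to super().__init__():

```python
from typing import Any, Dict, Iterable, Union

import torch

ParamsT = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]


def param_to_lr_from_groups(params: ParamsT, default_lr: float) -> Dict[torch.Tensor, float]:
    """Maps each parameter to its group's lr, falling back to the default lr."""
    mapping: Dict[torch.Tensor, float] = {}
    for entry in params:
        if isinstance(entry, dict):
            group_lr = entry.get("lr", default_lr)
            for p in entry["params"]:
                mapping[p] = group_lr
        else:
            # Bare tensors (no explicit group) use the default lr.
            mapping[entry] = default_lr
    return mapping
```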


# Initialize dictionary containing lists of .
self._per_group_state_lists: List[Dict[str, Any]] = [
@@ -1142,6 +1148,114 @@ def _per_group_step_impl(
masked_blocked_search_directions=masked_blocked_search_directions
)

@torch.no_grad()
def _per_group_step_experimental_lrs(
self,
state_lists: Dict[str, Any],
step: torch.Tensor,
neg_lrs: List[torch.Tensor],
beta1: float,
beta3: float,
weight_decay: float,
momentum_param: float,
dampening: float,
grafting_config_not_none: bool,
compute_root_inverse: bool,
use_decoupled_weight_decay: bool,
use_bias_correction: bool,
use_grafting_method: bool,
use_nesterov: bool,
) -> None:
# Incorporate L2-regularization or (coupled) weight decay if enabled.
# G <- G + lr * weight_decay * W
Contributor:
Reminder to self that we have a typo here; shouldn't have lr * weight_decay, just weight_decay.

cc: @tsunghsienlee

self._add_l2_regularization(
state_lists,
weight_decay,
use_decoupled_weight_decay,
)

with DequantizePreconditionersContext(
preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
), (
DequantizePreconditionersContext(
preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
)
if grafting_config_not_none
else contextlib.nullcontext()
):
# Update Shampoo and grafting preconditioners / factor matrices.
# Example for AdaGrad accumulation:
# L <- L + G * G^T
# R <- R + G^T * G
# V <- V + G^2 (element-wise)
# (and similar)
self._update_preconditioners(
state_lists,
step,
grafting_config_not_none,
)

# Compute matrix root inverse.
# L_inv <- L ** (-1/4)
# R_inv <- R ** (-1/4)
# (and similar)
self._compute_root_inverse(state_lists, compute_root_inverse)

# Compute filtered gradient or EMA of the gradients if beta1 > 0 and beta3 > 0.
# Note that we use two beta factors here akin to Lion.
# G_bar <- beta3 * G_tilde + (1 - beta3) * G
# G_tilde <- beta1 * G_tilde + (1 - beta1) * G
masked_filtered_grad_list = self._compute_filtered_grad_list(
state_lists,
step,
beta1,
beta3,
use_bias_correction,
)

# Precondition and graft filtered gradients.
# PT2 compile is currently disabled for preconditioning and grafting.
# TODO: Resolve preconditioning and grafting PT2 NEX issue and enable them.
#
# P_shampoo <- L_inv * G_bar * R_inv (and similar)
# P_grafting <- G_bar / (sqrt(V) + epsilon)
# P <- P_grafting if step < start_preconditioning_step
# P <- ||P_grafting|| / ||P_shampoo|| * P_shampoo otherwise
masked_blocked_search_directions = self._precondition_and_grafting(
state_lists,
masked_filtered_grad_list,
use_grafting_method,
grafting_config_not_none,
)

# Incorporate decoupled weight decay into search direction if enabled.
# P <- P + weight_decay * W
self._apply_decoupled_weight_decay(
state_lists,
masked_blocked_search_directions,
weight_decay,
use_decoupled_weight_decay,
)

# Update momentum optimizer state and use momentum / Nesterov if enabled.
# M <- momentum_param * M + (1 - dampening) * P
# P <- (1 - dampening) * P + momentum_param * M if use_nesterov
# P <- M otherwise.
self._update_momentum(
state_lists,
masked_blocked_search_directions,
momentum_param,
dampening,
use_nesterov,
)

# Updates parameters in distributed fashion.
# If DDP, executes AllGather communication to ensure all parameters are updated after local updates.
torch._foreach_mul_(masked_blocked_search_directions, neg_lrs)
Contributor:
Note that this is exactly the same as the code in _per_group_step_impl, with only the learning-rate handling changed; the experimental path just requires passing in a list of tensors instead of a single scalar lr.
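A toy illustration (not Shampoo code) of that one difference: torch._foreach_mul_ accepts either a single scalar or a matching list of tensors, so the experimental path can scale each search direction by its own negative lr in a single fused call.

```python
import torch

directions = [torch.ones(2, 3), torch.ones(4)]

# Scalar path: every tensor in the list is scaled by the same -lr.
torch._foreach_mul_([d.clone() for d in directions], -0.1)

# Per-parameter path: directions[i] is scaled in place by neg_lrs[i].
neg_lrs = [torch.tensor(-0.1), torch.tensor(-0.01)]
torch._foreach_mul_(directions, neg_lrs)
```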

state_lists[DISTRIBUTOR].update_params(
masked_blocked_search_directions=masked_blocked_search_directions
)

@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
@@ -1173,12 +1287,6 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]

# Iterate group step counter and define Python scalar step.
step = state_lists[STEP].add_(1)
# NOTE: Wrap scalar of group[LR] into a 0D tensor to avoid PT2 recompilation;
# Send 0D tensor to GPU in `non_blocking` to avoid QPS regression. Remove the gpu
# tensor impl once PT2 supports cpu 0D tensor properly.
lr = torch.tensor(group[LR], dtype=torch.float).to(
self._device, non_blocking=True
)
beta1 = group[BETAS][0]
beta3 = group[BETA3]
weight_decay = group[WEIGHT_DECAY]
@@ -1200,22 +1308,58 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
)
use_nesterov = group[USE_NESTEROV]

self._per_group_step(
state_lists,
step,
lr,
beta1,
beta3,
weight_decay,
momentum_param,
dampening,
grafting_config_not_none,
compute_root_inverse,
use_decoupled_weight_decay,
use_bias_correction,
use_grafting_method,
use_nesterov,
)
if self._experimental_param_to_lr is None:
# NOTE: Wrap scalar of group[LR] into a 0D tensor to avoid PT2 recompilation;
# Send 0D tensor to GPU in `non_blocking` to avoid QPS regression. Remove the gpu
# tensor impl once PT2 supports cpu 0D tensor properly.
lr = torch.tensor(group[LR], dtype=torch.float).to(
self._device, non_blocking=True
)
self._per_group_step(
state_lists,
step,
lr,
beta1,
beta3,
weight_decay,
momentum_param,
dampening,
grafting_config_not_none,
compute_root_inverse,
use_decoupled_weight_decay,
use_bias_correction,
use_grafting_method,
use_nesterov,
)
else:
local_block_info_list = compress_list(
state_lists[DISTRIBUTOR].global_block_info_list,
state_lists[DISTRIBUTOR].distributor_selector,
)
neg_lr_tensors = []
for local_block_info in local_block_info_list:
lr_scalar = self._experimental_param_to_lr(local_block_info.param)
lr = torch.tensor(-lr_scalar, dtype=torch.float).to(
self._device, non_blocking=True
)
neg_lr_tensors.append(lr)

self._per_group_step_experimental_lrs(
state_lists,
step,
neg_lr_tensors,
beta1,
beta3,
weight_decay,
momentum_param,
dampening,
grafting_config_not_none,
compute_root_inverse,
use_decoupled_weight_decay,
use_bias_correction,
use_grafting_method,
use_nesterov,
)

return loss

8 changes: 7 additions & 1 deletion distributed_shampoo/examples/ddp_cifar10_example.py
@@ -131,6 +131,12 @@
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
experimental_lrs=(
[float(f) for f in args.experimental_lrs.split(",")]
if args.experimental_lrs
else []
),
experimental_param_to_lr_mapping=args.experimental_param_to_lr_mapping,
)

# checks for checkpointing
@@ -140,7 +146,7 @@
raise ValueError(
"Distributed checkpointing is only supported with DistributedShampoo!"
)
if args.se_distributed_checkpoint and args.checkpoint_dir is None:
if args.use_distributed_checkpoint and args.checkpoint_dir is None:
raise ValueError(
"Trying to use distributed checkpointing but checkpoint directory is not provided!"
)