[Prototype] Add param-to-lr interface to distributed Shampoo #22
base: main
Conversation
Thanks @shintaro-iwasaki for putting this together!
As context, this work is motivated by the desire to support use cases where many parameter groups may exist. For example, a user may choose to set a different learning rate for a single parameter (or a small group of parameters). In this setting, our current distributed code will construct a separate Distributor instance for each parameter group and call separate communication primitives, which can lead to a high constant overhead from each AllGather call. Instead, we would like to see how we can potentially fuse parameter groups together automatically for certain cases so that a single Distributor is used for multiple parameter groups. This will also allow other things, like foreach operators, to be applied to multiple parameter groups simultaneously, which should additionally improve performance.
As discussed offline, I don't expect that we will land this particular code, but we will use this as a basis for formalizing the creation of virtual parameter groups. I'm now more convinced that this is easily possible as long as the fields we are fusing together are floats.
Writing comments for myself and @tsunghsienlee.
cc: @tsunghsienlee
@@ -474,6 +479,7 @@ def __init__(
        self._shampoo_pt2_compile_config: Optional[ShampooPT2CompileConfig] = (
            shampoo_pt2_compile_config
        )
        self._experimental_param_to_lr = experimental_param_to_lr
In order to support this properly, we will need to create a function that automatically constructs this mapping from each parameter (within each parameter group) to its learning rate, and modify the parameter groups typically defined by torch.optim.Optimizer (see https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer). This may need to be moved prior to the super().__init__() call above, so we can call this on the parameter groups.
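For concreteness, a minimal sketch of what such a helper could look like (build_param_to_lr is a hypothetical name; the real implementation would live inside the optimizer):

```python
from typing import Any, Dict, List

import torch


def build_param_to_lr(param_groups: List[Dict[str, Any]]) -> Dict[torch.Tensor, float]:
    """Map each parameter (within each parameter group) to its group's learning rate.

    Tensors hash by identity, so this mapping would need to be rebuilt if the
    groups' lr fields are later mutated (e.g., by an LR scheduler).
    """
    param_to_lr: Dict[torch.Tensor, float] = {}
    for group in param_groups:
        for param in group["params"]:
            param_to_lr[param] = group["lr"]
    return param_to_lr
```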
@@ -326,6 +330,7 @@ def __init__(
        precision_config: Optional[PrecisionConfig] = None,
        use_protected_eigh: bool = True,
        track_root_inv_residuals: bool = False,
        experimental_param_to_lr: Optional[Callable[[torch.Tensor], float]] = None,
We will not want to expose this directly to the user, but instead create a flag that merges parameter groups that have different lr, betas, beta3, epsilon, momentum, dampening, or weight_decay, but share the same fields everywhere else.
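A rough sketch of what such automatic fusion could look like, assuming the fields listed above are the only per-parameter ones (fuse_param_groups and param_to_hparams are hypothetical names):

```python
from typing import Any, Dict, List

# Fields from the comment above that may differ per parameter after fusion;
# all remaining fields must match for two groups to be merged.
PER_PARAM_FIELDS = ("lr", "betas", "beta3", "epsilon", "momentum", "dampening", "weight_decay")


def fuse_param_groups(param_groups: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Merge parameter groups that agree on every field outside PER_PARAM_FIELDS."""
    buckets: Dict[str, Dict[str, Any]] = {}
    for group in param_groups:
        shared = {k: v for k, v in group.items() if k != "params" and k not in PER_PARAM_FIELDS}
        key = repr(sorted(shared.items()))  # cheap structural key; fine for a sketch
        bucket = buckets.setdefault(key, {**shared, "params": [], "param_to_hparams": {}})
        for param in group["params"]:
            bucket["params"].append(param)
            bucket["param_to_hparams"][param] = {f: group[f] for f in PER_PARAM_FIELDS if f in group}
    return list(buckets.values())
```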
    use_nesterov: bool,
) -> None:
    # Incorporate L2-regularization or (coupled) weight decay if enabled.
    # G <- G + lr * weight_decay * W
Reminder to self that we have a typo here; the comment shouldn't have lr * weight_decay, just weight_decay.
cc: @tsunghsienlee
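For reference, the corrected update applied with a foreach op would look roughly like this (the tensor lists below are stand-ins for the blocked parameters and gradients):

```python
import torch

weight_decay = 1e-4
params = [torch.randn(4, 4), torch.randn(4)]
grads = [torch.randn_like(p) for p in params]

# Coupled weight decay / L2 regularization: G <- G + weight_decay * W
# (no learning-rate factor, per the comment above).
torch._foreach_add_(grads, params, alpha=weight_decay)
```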
    # Updates parameters in distributed fashion.
    # If DDP, executes AllGather communication to ensure all parameters are updated after local updates.
    torch._foreach_mul_(masked_blocked_search_directions, neg_lrs)
Note that this is exactly the same as the code for _per_group_step_impl with only the learning rate changed; this experimental path just requires us to pass in a list of tensors.
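A minimal, self-contained illustration of this pattern (the names are stand-ins for the masked/blocked lists in the actual code):

```python
import torch

params = [torch.randn(4, 4), torch.randn(4)]
search_directions = [torch.randn_like(p) for p in params]
# One negative learning rate per (blocked) parameter, e.g. produced from the
# experimental param-to-lr callable.
neg_lrs = [torch.tensor(-0.1), torch.tensor(-0.2)]

torch._foreach_mul_(search_directions, neg_lrs)  # scale each direction by its own -lr
torch._foreach_add_(params, search_directions)   # apply the updates in place
```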
Summary
This PR prototypes a distributed Shampoo optimizer that efficiently executes a model with multiple parameter groups, each having different learning rates.
Problem
PyTorch optimizers assume one learning rate per parameter group, so achieving different learning rates for different parameters requires creating many parameter groups. However, the current Distributed Shampoo implementation cannot handle this efficiently: AllGather is issued per parameter group, leading to overly fine-grained communication that hurts performance.
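For example, with stock PyTorch optimizers, per-parameter learning rates force one group per parameter (a sketch with SGD for illustration):

```python
import torch

model = torch.nn.Sequential(*[torch.nn.Linear(16, 16) for _ in range(4)])

# One group per parameter is the only stock way to give each parameter its own lr,
# which in distributed Shampoo currently means one Distributor (and AllGather) per group.
param_groups = [
    {"params": [p], "lr": 0.1 * (i + 1)} for i, p in enumerate(model.parameters())
]
optimizer = torch.optim.SGD(param_groups, lr=0.1)
```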
Proposal (Prototype)
This prototype introduces a new interface that allows mapping parameters to learning rates using a function. This experimental interface enables developers to use a single parameter group while still applying different learning rates to different parameters.
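A minimal usage sketch, assuming the experimental_param_to_lr constructor argument shown in the diff above (the import path and hyperparameters here are illustrative assumptions):

```python
import torch

# Import path assumed for illustration.
from distributed_shampoo import DistributedShampoo

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 2))

# Give the final layer a larger learning rate while keeping a single parameter group.
final_layer_param_ids = {id(p) for p in model[-1].parameters()}


def param_to_lr(param: torch.Tensor) -> float:
    return 0.2 if id(param) in final_layer_param_ids else 0.1


optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.1,  # default lr; the callable overrides it per parameter
    experimental_param_to_lr=param_to_lr,  # prototype hook added by this PR
)
```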
Limitations
It's important to note that this is a prototype and has some fundamental problems.
Testing
The DDP Cifar10 example was modified to test this feature.
Learning rates for different params

--experimental-lrs=lr0,lr1,lr2,... is the new benchmark interface to set learning rates for different parameters (currently in a round-robin manner). The correctness of --experimental-lrs is tested as follows:
- --experimental-lrs=0.1,0.1 and --lr=0.1 on a single trainer return the same local lifetime loss.
- --experimental-lrs=0.1,0.2 returns a different local lifetime loss.

New interface (proposed by this PR)

Setting experimental_param_to_lr_mapping allows the benchmark to create only one parameter group, which should be more efficient in terms of communication. The correctness of --experimental-param-to-lr-mapping is tested as follows:
- With --experimental-lrs=0.1,0.2, the local lifetime loss is the same with and without --experimental_param_to_lr_mapping=1 on a single trainer.
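For illustration, a hedged sketch of the round-robin assignment described above (the actual benchmark code may differ):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

# --experimental-lrs=0.1,0.2 parsed into a Python list; the i-th parameter
# receives lrs[i % len(lrs)], i.e. learning rates are assigned round-robin.
lrs = [0.1, 0.2]
param_to_lr = {id(p): lrs[i % len(lrs)] for i, p in enumerate(model.parameters())}


def experimental_param_to_lr(param: torch.Tensor) -> float:
    return param_to_lr[id(param)]
```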