
[Prototype] Add param-to-lr interface to distributed Shampoo #22

Open · wants to merge 2 commits into main
Conversation

shintaro-iwasaki (Contributor)
Summary

This PR prototypes a distributed Shampoo optimizer that efficiently executes a model with multiple parameter groups, each having different learning rates.

Problem

PyTorch's optimizer assumes one learning rate per parameter group, so applying different learning rates to different parameters requires creating many parameter groups. However, the current Distributed Shampoo implementation cannot handle this efficiently because an AllGather is issued per parameter group, leading to overly fine-grained communication that hurts performance.

Proposal (Prototype)

This prototype introduces a new interface that allows mapping parameters to learning rates using a function. This experimental interface enables developers to use a single parameter group while still applying different learning rates to different parameters.
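A minimal sketch of what this could look like in user code, assuming the experimental_param_to_lr argument from the diff further down; the model, the mapping rule, and the import path are illustrative only:

import torch
import torch.nn as nn

from distributed_shampoo import DistributedShampoo  # import path assumed

# Illustrative model: suppose the final layer should use a smaller learning rate.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
last_layer_param_ids = {id(p) for p in model[-1].parameters()}

def param_to_lr(param: torch.Tensor) -> float:
    # Any deterministic rule works; this one keys off the final layer.
    return 0.01 if id(param) in last_layer_param_ids else 0.1

# A single parameter group suffices; per-parameter learning rates come from
# the mapping function rather than from extra parameter groups.
# (Other Shampoo hyperparameters are omitted for brevity.)
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.1,
    experimental_param_to_lr=param_to_lr,
)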

Limitations

It's important to note that this is a prototype with some fundamental problems:

  • Unusual optimizer interface
  • Burden on users to provide a mapping function
  • Lack of checkpointing support (not tested)
  • No proper unit tests
  • PT2 is not supported (not tested)
  • No performance testing has been conducted

Testing

The DDP Cifar10 example was modified to test this feature.

COMMON_ARGS="-m distributed_shampoo.examples.ddp_cifar10_example --optimizer-type DISTRIBUTED_SHAMPOO --precondition-frequency 100 --grafting-type ADAM --num-trainers-per-group -1 --use-bias-correction --use-decoupled-weight-decay --use-merge-dims --epochs 5 --local-batch-size 128"

Learning rates for different params
--experimental-lrs=lr0,lr1,lr2,... is the new benchmark interface that sets learning rates for different parameters (currently assigned in a round-robin manner; a sketch of this assignment follows the transcripts below). The correctness of --experimental-lrs is tested as follows:

  1. Both --experimental-lrs=0.1,0.1 and --lr=0.1 on a single trainer return the same local lifetime loss.
  2. --experimental-lrs=0.1,0.2 returns a different local lifetime loss.
$ torchrun --standalone --nnodes=1 --nproc_per_node=1 ${COMMON_ARGS} --lr="0.1"
...
INFO: Epoch: 4 | Iteration: 1955 | Local Lifetime Loss: 55.04213333129883 | Local Window Loss: 112.79435729980469

$ torchrun --standalone --nnodes=1 --nproc_per_node=1 ${COMMON_ARGS} --experimental-lrs="0.1,0.1"
...
INFO: Epoch: 4 | Iteration: 1955 | Local Lifetime Loss: 55.04213333129883 | Local Window Loss: 112.79435729980469

$ torchrun --standalone --nnodes=1 --nproc_per_node=1 ${COMMON_ARGS} --experimental-lrs="0.1,0.2"
...
INFO: Epoch: 4 | Iteration: 1955 | Local Lifetime Loss: 115.76996612548828 | Local Window Loss: 249.4604034423828
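For reference, a sketch of the round-robin assignment that --experimental-lrs implies; this is an assumption based on the description above, and the actual benchmark code may differ:

from typing import Any, Dict, List

import torch.nn as nn

def build_round_robin_groups(model: nn.Module, lrs: List[float]) -> List[Dict[str, Any]]:
    # One parameter group per learning rate; parameters are distributed
    # over the groups in a round-robin manner.
    groups: List[Dict[str, Any]] = [{"params": [], "lr": lr} for lr in lrs]
    for i, param in enumerate(model.parameters()):
        groups[i % len(lrs)]["params"].append(param)
    return groups

# --experimental-lrs="0.1,0.2" thus alternates 0.1 and 0.2 over the model's
# parameters, producing two parameter groups.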

New interface (proposed by this PR)
Setting --experimental-param-to-lr-mapping allows the benchmark to create only one parameter group, which should be more communication-efficient. The correctness of --experimental-param-to-lr-mapping is tested as follows:

  1. With --experimental-lrs=0.1,0.2, the local lifetime loss on a single trainer is the same with --experimental-param-to-lr-mapping=1 and with --experimental-param-to-lr-mapping=0.
$ torchrun --standalone --nnodes=1 --nproc_per_node=1 ${COMMON_ARGS} --experimental-lrs="0.1,0.2" --experimental-param-to-lr-mapping=1
...
INFO: Epoch: 4 | Iteration: 1955 | Local Lifetime Loss: 115.76996612548828 | Local Window Loss: 249.4604034423828

$ torchrun --standalone --nnodes=1 --nproc_per_node=1 ${COMMON_ARGS} --experimental-lrs="0.1,0.2" --experimental-param-to-lr-mapping=0
...
INFO: Epoch: 4 | Iteration: 1955 | Local Lifetime Loss: 115.76996612548828 | Local Window Loss: 249.4604034423828
  2. With multiple trainers (2 trainers), all of the following runs produced matching results.
$ torchrun --standalone --nnodes=1 --nproc_per_node=2 ${COMMON_ARGS} --lr=0.1
...
Epoch: 4 | Iteration: 980 | Global Lifetime Loss: 53.14086151123047 | Global Window Loss: 4.617415428161621 

$ torchrun --standalone --nnodes=1 --nproc_per_node=2 ${COMMON_ARGS} --experimental-lrs="0.1,0.1" --experimental-param-to-lr-mapping=1
...
Epoch: 4 | Iteration: 980 | Global Lifetime Loss: 53.14086151123047 | Global Window Loss: 4.617415428161621 

$ torchrun --standalone --nnodes=1 --nproc_per_node=2 ${COMMON_ARGS} --experimental-lrs="0.1,0.1" --experimental-param-to-lr-mapping=0
...
Epoch: 4 | Iteration: 980 | Global Lifetime Loss: 53.14086151123047 | Global Window Loss: 4.617415428161621

@hjmshi (Contributor) left a comment

Thanks @shintaro-iwasaki for putting this together!

As context, this work is motivated by the desire to support use cases where many parameter groups may exist. For example, a user may choose to set a different learning rate for a single parameter (or a small group of parameters). In this setting, our current distributed code constructs a separate Distributor instance for each parameter group and calls separate communication primitives, which can lead to a high constant overhead from each AllGather call. Instead, we would like to see how we can automatically fuse parameter groups for certain cases so that a single Distributor serves multiple parameter groups. This would also allow foreach operators to be applied across multiple parameter groups simultaneously, which should further improve performance.

As discussed offline, I don't expect that we will land this particular code, but we will use this as a basis for formalizing the creation of virtual parameter groups. I'm now more convinced that this is easily possible as long as the fields we are fusing together are floats.

Writing comments for myself and @tsunghsienlee.

cc: @tsunghsienlee

@@ -474,6 +479,7 @@ def __init__(
self._shampoo_pt2_compile_config: Optional[ShampooPT2CompileConfig] = (
shampoo_pt2_compile_config
)
self._experimental_param_to_lr = experimental_param_to_lr
Contributor comment:

In order to handle support for this properly, we will need to create a function that constructs this mapping automatically from each parameter (within each parameter group) to its learning rate and modify the parameter groups defined typically by torch.optim.Optimizer (see https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer).

This may need to be moved before the super().__init__() call above so that we can apply it to the parameter groups.
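A rough sketch of the kind of helper described here, assuming the standard param_groups layout of torch.optim.Optimizer; the name and structure are illustrative, not the eventual implementation:

from typing import Any, Dict, List

import torch

def build_param_to_lr(param_groups: List[Dict[str, Any]]) -> Dict[torch.Tensor, float]:
    # Record each parameter's learning rate before the groups are (virtually)
    # fused, so the per-parameter learning rates survive the merge.
    param_to_lr: Dict[torch.Tensor, float] = {}
    for group in param_groups:
        for param in group["params"]:
            param_to_lr[param] = group["lr"]
    return param_to_lr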

@@ -326,6 +330,7 @@ def __init__(
precision_config: Optional[PrecisionConfig] = None,
use_protected_eigh: bool = True,
track_root_inv_residuals: bool = False,
experimental_param_to_lr: Optional[Callable[[torch.Tensor], float]] = None,
Contributor comment:

We will not want to expose this directly to the user, but create a flag that merges parameter groups that have different lr, betas, beta3, epsilon, momentum, dampening, or weight_decay, but share the same fields everywhere else.
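One possible shape for that flag's merging logic, sketched under the assumption of standard param_groups dicts; the field list follows the comment above, and everything else is illustrative:

from typing import Any, Dict, List, Tuple

# Per-parameter hyperparameters allowed to differ inside a fused group.
_FUSABLE_FIELDS = frozenset(
    {"lr", "betas", "beta3", "epsilon", "momentum", "dampening", "weight_decay"}
)

def fuse_param_groups(param_groups: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Merge groups that agree on every non-fusable field. The fused group keeps
    # the first group's hyperparameters; the differing per-parameter values
    # (e.g. lr) would be tracked separately, as in the mapping sketched earlier.
    fused: Dict[Tuple, Dict[str, Any]] = {}
    for group in param_groups:
        key = tuple(
            sorted(
                (name, repr(value))
                for name, value in group.items()
                if name != "params" and name not in _FUSABLE_FIELDS
            )
        )
        if key not in fused:
            fused[key] = {name: value for name, value in group.items() if name != "params"}
            fused[key]["params"] = []
        fused[key]["params"].extend(group["params"])
    return list(fused.values())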

use_nesterov: bool,
) -> None:
# Incorporate L2-regularization or (coupled) weight decay if enabled.
# G <- G + lr * weight_decay * W
Contributor comment:

Reminder to self that we have a typo here; the comment shouldn't say lr * weight_decay, just weight_decay.

cc: @tsunghsienlee
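For context, a minimal sketch of the intended coupled update once the comment is fixed; the function name is a placeholder and not part of the codebase:

import torch

def add_coupled_weight_decay(grad: torch.Tensor, param: torch.Tensor, weight_decay: float) -> torch.Tensor:
    # Coupled L2 regularization / weight decay: G <- G + weight_decay * W
    # (no learning-rate factor; the lr in the original comment is the typo).
    return grad.add(param, alpha=weight_decay)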


# Updates parameters in distributed fashion.
# If DDP, executes AllGather communication to ensure all parameters are updated after local updates.
torch._foreach_mul_(masked_blocked_search_directions, neg_lrs)
Contributor comment:

Note that this is exactly the same as the code in _per_group_step_impl except for the learning-rate handling; the experimental path simply requires passing in a list of tensors.
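An illustration of the difference between the two paths; the tensors and values are placeholders, and torch._foreach_mul_ accepts either a single scalar or a matching list of tensors:

import torch

# Placeholder per-block search directions (one tensor per block).
directions_single_lr = [torch.randn(4, 4) for _ in range(3)]
directions_per_block_lr = [d.clone() for d in directions_single_lr]

# Existing path: one learning rate for the whole parameter group.
lr = 0.1
torch._foreach_mul_(directions_single_lr, -lr)

# Experimental path: one negated learning rate per block, passed as 0-dim
# tensors, so blocks within the same (fused) group can use different lrs.
neg_lrs = [torch.tensor(-0.1), torch.tensor(-0.2), torch.tensor(-0.2)]
torch._foreach_mul_(directions_per_block_lr, neg_lrs)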
