
Why don't ML engineers use shampoo ?🧴 #3178

Open
G-structure opened this issue Oct 18, 2024 · 2 comments

Comments

@G-structure

G-structure commented Oct 18, 2024

Hey,

I have been using Meta's implementation of Distributed Shampoo and am seeing ~20% faster convergence of transformer-based models compared to AdamW. Simo Ryu has done some nice investigations into the advantages of Shampoo.

I am looking to use Shampoo and SOAP as optimizers in accelerate, but their current implementations introduce some breaking changes.

Focusing on Shampoo for now:

Distributed Shampoo disables state_dict and load_state_dict in favor of custom distributed_state_dict and load_distributed_state_dict methods, both of which require the model's named_parameters() to be passed in as an argument. More info as to why here.

I have a hacky commit here that patches accelerate/optimizers. However, I am still forced to bypass accelerate.save() and use dist_checkpoint.save_state_dict() directly, since the optimizer entry in the state_dict needs access to the model's named_parameters().

# torch's distributed checkpointing API, used here instead of accelerate.save()
import torch.distributed.checkpoint as dist_checkpoint

state_dict = {
    "model": model.state_dict(),
    # Shampoo's state dict needs the name -> parameter mapping
    "optim": optimizer.distributed_state_dict(key_to_param=model.named_parameters()),
}
dist_checkpoint.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_checkpoint.FileSystemWriter(CHECKPOINT_DIR),
)

You can see this here in my e2-tts training code. I am able to save the model weights, but I am not yet able to load them again when using accelerate. This is where I am currently stuck.
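
For reference, this is roughly what I expect the loading side to look like when calling distributed_shampoo directly (a sketch based on its documentation; the exact signatures may differ), and it is this flow that I can't cleanly route through accelerate:

import torch.distributed.checkpoint as dist_checkpoint

# build a template state_dict so the loader knows the layout to fill in
state_dict = {
    "model": model.state_dict(),
    "optim": optimizer.distributed_state_dict(key_to_param=model.named_parameters()),
}
dist_checkpoint.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_checkpoint.FileSystemReader(CHECKPOINT_DIR),
)
model.load_state_dict(state_dict["model"])
# Shampoo again needs the name -> parameter mapping to restore its state
optimizer.load_distributed_state_dict(
    state_dict=state_dict["optim"],
    key_to_param=model.named_parameters(),
)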

Also, since I don't have access to named_parameters() until accelerate.prepare_model() is called, the Shampoo optimizer needs to be defined inside the model definition, which makes it awkward to switch between optimizers; see here.
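
One workaround I've been considering is passing an optimizer factory instead of a constructed optimizer and only building Shampoo once the prepared model is available. A rough sketch (the factory pattern and its placement are just illustrative, not accelerate API):

from functools import partial

from distributed_shampoo import AdamGraftingConfig, DistributedShampoo

# defer construction until the prepared model's parameters are available
make_optimizer = partial(
    DistributedShampoo,
    lr=7.5e-5,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    grafting_config=AdamGraftingConfig(beta2=0.999, epsilon=1e-08),
)

# inside the trainer, after the model has been prepared:
# model = accelerator.prepare_model(model)
# optimizer = make_optimizer(model.parameters())
# optimizer = accelerator.prepare_optimizer(optimizer)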

Ideally I'd be able to do something like this, where I pass in the optimizer just as I can with AdamW:

e2tts = E2TTS(
    cond_drop_prob=0.0,
    transformer = dict(
        dim = 512,
        depth = 2,
        heads = 6,
        skip_connect_type = 'concat'
    ),
    mel_spec_kwargs = dict(
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        sampling_rate = 24000,
    ),
    frac_lengths_mask = (0.7, 0.9)
)


optimizer = DistributedShampoo(
    e2tts.parameters(),
    lr=7.5e-5,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=False,
    grafting_config=AdamGraftingConfig(
        beta2=0.999,
        epsilon=1e-08,
    ),
)

trainer = E2Trainer(
    e2tts,
    optimizer
)

trainer.train(train_dataset, epochs, batch_size, save_step=50)

Of course, when I set everything up with torch DDP instead of accelerate, everything works as intended :/

What would be the best approach for accelerate to support these custom optimizers (ones not part of torch)? My current plan is to write a ShampooPlugin along the lines of the DeepSpeedPlugin, but it would be nice if the Shampoo optimizer could be detected automatically without having to change the accelerate config. I am willing to put in the work to solve this so more projects can benefit from using these new optimizers with accelerate.
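
To make the idea concrete, I picture the plugin as little more than a dataclass that accelerate consults when building and checkpointing the optimizer, something like this (entirely hypothetical; none of this is existing accelerate API):

from dataclasses import dataclass

@dataclass
class ShampooPlugin:
    # Shampoo hyperparameters accelerate could forward when constructing the optimizer
    max_preconditioner_dim: int = 8192
    precondition_frequency: int = 100
    use_decoupled_weight_decay: bool = False

    def optimizer_state_dict(self, optimizer, model):
        # Shampoo serializes its state via its own method and needs the
        # name -> parameter mapping, which accelerate would supply here
        return optimizer.distributed_state_dict(
            key_to_param=model.named_parameters()
        )

    def load_optimizer_state_dict(self, optimizer, model, state_dict):
        optimizer.load_distributed_state_dict(
            state_dict=state_dict,
            key_to_param=model.named_parameters(),
        )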

Any guidance would be much appreciated. :)

@bghira

bghira commented Oct 18, 2024

That optimizer lacks other torch-specific expectations, and with things like learning rate schedulers, some off-the-wall optimizers just don't work without modification. The distributed ZeRO Shampoo optimizer is a WIP technical prototype and not meant to be used in production, for example.

The SOAP one worked as expected for me; it just needs a closure argument on its step(). See here: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/optimizers/soap/__init__.py
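
Roughly, the shim is just this kind of wrapper (a minimal sketch; the class name and import path are illustrative, not the actual SimpleTuner code):

import torch

from soap import SOAP  # illustrative import; use whatever module defines the SOAP optimizer

class SOAPWithClosure(SOAP):
    # give step() the optional closure signature that torch / accelerate expect
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # re-evaluate the model and loss, as torch's built-in optimizers do
            with torch.enable_grad():
                loss = closure()
        super().step()
        return loss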

@bghira

bghira commented Oct 18, 2024

Other problems with the original optimizer implementations linked are that they do not work with torch.compile and still have very slow performance (more pronounced in SOAP) and high memory overhead (also more pronounced in SOAP), even with ZeRO offload.
