
[FEATURE] benchmark, this code runs more than 3x slower compared to pytorch #1531

Open
thegodone opened this issue Oct 26, 2024 · 10 comments

@thegodone

thegodone commented Oct 26, 2024

Describe the bug
MLX is much slower than PyTorch and TensorFlow on the GPU: ~0.46 s per epoch for PyTorch and TensorFlow versus ~1.55 s for MLX.

To Reproduce

Follow the instructions and code from https://github.com/thegodone/apple_ai_model

Expected behavior
MLX should run at a speed comparable to TensorFlow and PyTorch; instead it is much slower.

Desktop:

  • OS: macOS 15.1
  • mlx: 0.19.1
  • pytorch: 2.5
  • tensorflow: 2.15
@awni
Member

awni commented Oct 28, 2024

TL;DR: with a few fixes + mx.compile it comes down to 0.394 s, compared to 0.704 s for PyTorch, on an M1 Max.

More comments on this:

  • The model builds atom_neighbor by looping over the batch in Python:

    atom_neighbor = [atom_list[i][atom_degree_list[i]] for i in range(batch_size)]
    atom_neighbor = mx.stack(atom_neighbor, axis=0)

    which is completely un-vectorized and will be very slow. You can do this in a single indexing op like so (a shape-check sketch follows this list):

    atom_neighbor = mx.take_along_axis(atom_list[..., None, :], atom_degree_list[..., None], axis=-3)

  • Put the model in eval mode when testing to avoid using dropout (model.eval()) and then back to training mode for training (model.train()).
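
To see that the two gathers are equivalent, here is a minimal, self-contained shape check; the array names match the snippet above, but the sizes and random contents are made up for illustration:

import mlx.core as mx

batch_size, n_atoms, n_neighbors, n_feat = 4, 10, 6, 8   # hypothetical sizes
atom_list = mx.random.normal((batch_size, n_atoms, n_feat))
atom_degree_list = mx.random.randint(0, n_atoms, (batch_size, n_atoms, n_neighbors))

# Looped gather: index each batch element separately, then stack.
looped = mx.stack([atom_list[i][atom_degree_list[i]] for i in range(batch_size)], axis=0)

# Single broadcasted gather along the atom axis.
vectorized = mx.take_along_axis(atom_list[..., None, :], atom_degree_list[..., None], axis=-3)

print(looped.shape, vectorized.shape)          # (4, 10, 6, 8) for both
print(mx.allclose(looped, vectorized).item())  # True

The inserted size-1 axes are what make the broadcast work: atom_list[..., None, :] adds a dummy neighbor axis and atom_degree_list[..., None] adds a dummy feature axis, so the single gather expands over both.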

With the dropout fix + switching out those unvectorized ops, the training time goes down to 0.474 s.

Using mx.compile gets it down to 0.394 s.

Here's roughly what that looks like:

lr_schedule = cosineannealingwarmrestartfactor(initial_lr, restarts, decay_step, warmup_factor)

optimizer = optim.AdamW(learning_rate=lr_schedule, weight_decay=10**-weight_decay)

def loss_fn(y_hat, y):
    y = mx.reshape(y, y_hat.shape)
    return mx.mean(nn.losses.mse_loss(y_hat, y))

from functools import partial

# State the compiled step should track: the model's parameters, the optimizer
# state, and the RNG state (needed because dropout draws random numbers).
state = [model, optimizer.state, mx.random.state]

def forward_fn(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, labels):
    _, y_hat = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)
    loss = loss_fn(y_hat, labels)
    return loss, y_hat

# Compile the whole training step; passing `state` as inputs/outputs lets
# mx.compile track the in-place updates to the model, optimizer, and RNG state.
@partial(mx.compile, inputs=state, outputs=state)
def step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, labels):
    loss_and_grad_fn = nn.value_and_grad(model, forward_fn)
    (loss, y_hat), grads = loss_and_grad_fn(
        x_atom=x_atom,
        x_bonds=x_bonds,
        x_atom_index=x_atom_index,
        x_bond_index=x_bond_index,
        x_mask=x_mask,
        labels=labels,
    )
    optimizer.update(model, grads)
    return loss

And in the train loop:

    for counter, train_batch in enumerate(batch_list):
        batch_df = dataset.loc[train_batch, :]
        smiles_list = batch_df.cano_smiles.values
        y_val = mx.array(batch_df[tasks[0]].values)
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list, feature_dicts)
        x_atom = mx.array(x_atom)
        x_bonds = mx.array(x_bonds)
        x_atom_index = mx.array(x_atom_index)
        x_bond_index = mx.array(x_bond_index)
        x_mask = mx.array(x_mask)


        loss = step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, y_val)
        # Evaluate the state to materialize the lazily-computed updates each step.
        mx.eval(state)

@thegodone
Author

This is great, thanks for the tips and improvements. The gather function (aka take_along_axis) was a mystery for me to set up, so this is great news.

@thegodone
Author

Do you think the GRUCell could be sped up further?

@awni
Member

awni commented Oct 28, 2024

Do you think the GRUCell could be sped up further?

Possibly... but I'm not sure it's the bottleneck in your model right now, so I wouldn't over-index on that. The best thing to do is figure out where the actual bottleneck is (sometimes this is not so easy) and focus on it.
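
One crude way to do that, given MLX's lazy evaluation, is to force evaluation around each stage and time it. A rough sketch along those lines (the timed helper and the sub-module names in the comments are made up for illustration):

import time
import mlx.core as mx

def timed(label, fn, *args):
    # Make sure pending work is finished, run fn, force its outputs, report wall time.
    mx.eval(args)
    tic = time.perf_counter()
    out = fn(*args)
    mx.eval(out)
    print(f"{label}: {time.perf_counter() - tic:.4f} s")
    return out

# Hypothetical usage: time each stage of the forward pass separately, e.g.
# h = timed("atom embedding", model.embed, x_atom)
# h = timed("GRU updates", model.gru, h)
# y = timed("readout", model.readout, h)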

@thegodone
Author

thegodone commented Oct 28, 2024

I did not find a good tutorial on how to set up and properly use Xcode profiler tracing for the GPU / MPS, so I am a little blind on the bottleneck investigation.
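
For reference, MLX can also write out a Metal GPU trace programmatically, which Xcode can then open; a minimal sketch, assuming the mx.metal.start_capture / mx.metal.stop_capture hooks and reusing the step/state from the snippet above:

import mlx.core as mx

# Launch the process with MTL_CAPTURE_ENABLED=1 so Metal allows programmatic capture.
mx.metal.start_capture("mlx_step.gputrace")
loss = step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, y_val)
mx.eval(state)  # force the compiled step to actually run under capture
mx.metal.stop_capture()
# Open mlx_step.gputrace in Xcode to inspect per-kernel GPU timings.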

@thegodone
Author

thegodone commented Oct 28, 2024

On my machine (M3 Max, 128 GB), without your new "dropout speedup" commit, I now get:
Epoch: 42 | loss: 0.00000 | rmse: 0.596 | rmse: 0.664 | rmse: 0.641 | LR: 0.002109 | Time: 0.6805589199066162

thegodone added a commit to thegodone/apple_ai_model that referenced this issue Oct 28, 2024
@thegodone
Author

thegodone commented Oct 30, 2024

Now training is very fast, 0.285 sec per epoch (using the bernoulli commit), but I have a major issue that I have also seen in many other trials: the training rarely reaches a good performance, whereas torch and keras are more stable and consistently good. This is a real blocker for using MLX, as you need to train your model 10 to 20 times to get a good result, while torch and keras are systematically good (0.50-0.55), even if keras is slightly worse than pytorch.

@angeloskath
Member

Are you using lr warmup or any scheduler? In general, hyperparameters and/or initialization won't necessarily transfer 1:1 from framework to framework, so you might want to play around with these a bit. If there is any reason to suspect a numerical inaccuracy, feel free to open an issue with more information on that.
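
For example, a linear warmup followed by cosine decay can be assembled from MLX's built-in schedule helpers; a minimal sketch with made-up step counts and rates, not necessarily matching the schedule in the original code:

import mlx.optimizers as optim

warmup = optim.linear_schedule(0.0, 2e-3, steps=500)         # 500 warmup steps up to 2e-3
cosine = optim.cosine_decay(2e-3, decay_steps=10_000)        # then cosine decay
lr_schedule = optim.join_schedules([warmup, cosine], [500])  # switch at step 500

optimizer = optim.AdamW(learning_rate=lr_schedule, weight_decay=1e-4)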

@thegodone
Author

Indeed, I hope it is this one: #1542. For me the difference is significant and cannot come from just the lr warmup scheduler.
