[FEATURE] benchmark, this code runs more than 3x slower compared to pytorch #1531
Comments
TLDR: with a few fixes + mx.compile the training time comes down a lot. More comments on this below.

This part of the code:
```python
atom_neighbor = [atom_list[i][atom_degree_list[i]] for i in range(batch_size)]
atom_neighbor = mx.stack(atom_neighbor, axis=0)
```

is completely un-vectorized and will be very slow. You can do it in a single indexing op like so:

```python
atom_neighbor = mx.take_along_axis(atom_list[..., None, :], atom_degree_list[..., None], axis=-3)
```
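As a quick sanity check (the toy shapes below are just assumptions for illustration: `atom_list` as `(batch, num_atoms, feat)` and `atom_degree_list` as `(batch, num_atoms, max_degree)` of neighbor indices), the loop and the single gather give the same result:

```python
import mlx.core as mx

batch_size, num_atoms, max_degree, feat = 2, 5, 3, 4
atom_list = mx.random.normal((batch_size, num_atoms, feat))
atom_degree_list = mx.random.randint(0, num_atoms, (batch_size, num_atoms, max_degree))

# Un-vectorized version from the original code.
looped = mx.stack(
    [atom_list[i][atom_degree_list[i]] for i in range(batch_size)], axis=0
)

# Single vectorized gather along the atom axis.
vectorized = mx.take_along_axis(
    atom_list[..., None, :], atom_degree_list[..., None], axis=-3
)

print(looped.shape, vectorized.shape)          # (2, 5, 3, 4) for both
print(mx.allclose(looped, vectorized).item())  # True
```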
With the dropout fix + switching to those vectorized ops, the training time goes down quite a bit more. Using mx.compile for the training step also helps. Here's roughly what that looks like:

```python
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

lr_schedule = cosineannealingwarmrestartfactor(initial_lr, restarts, decay_step, warmup_factor)
optimizer = optim.AdamW(learning_rate=lr_schedule, weight_decay=10**-weight_decay)

def loss_fn(y_hat, y):
    y = mx.reshape(y, y_hat.shape)
    return mx.mean(nn.losses.mse_loss(y_hat, y))

# Everything the compiled step mutates (model parameters, optimizer state,
# random state for dropout) has to be passed to mx.compile as state.
state = [model, optimizer.state, mx.random.state]

def forward_fn(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, labels):
    _, y_hat = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)
    loss = loss_fn(y_hat, labels)
    return loss, y_hat

@partial(mx.compile, inputs=state, outputs=state)
def step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, labels):
    loss_and_grad_fn = nn.value_and_grad(model, forward_fn)
    (loss, y_hat), grads = loss_and_grad_fn(
        x_atom=x_atom,
        x_bonds=x_bonds,
        x_atom_index=x_atom_index,
        x_bond_index=x_bond_index,
        x_mask=x_mask,
        labels=labels,
    )
    optimizer.update(model, grads)
    return loss
```

And in the train loop:

```python
for counter, train_batch in enumerate(batch_list):
    batch_df = dataset.loc[train_batch, :]
    smiles_list = batch_df.cano_smiles.values
    y_val = mx.array(batch_df[tasks[0]].values)

    x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list, feature_dicts)
    x_atom = mx.array(x_atom)
    x_bonds = mx.array(x_bonds)
    x_atom_index = mx.array(x_atom_index)
    x_bond_index = mx.array(x_bond_index)
    x_mask = mx.array(x_mask)

    loss = step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, y_val)
    # Evaluate the lazily-computed state (parameters, optimizer, RNG) once per step.
    mx.eval(state)
```
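By the way, `cosineannealingwarmrestartfactor` above is a helper from the training script rather than an MLX built-in. As a rough sketch (the helper and the exact meaning of its arguments are assumptions here), a warmup + cosine-with-restarts schedule could be assembled from MLX's built-in schedules like this:

```python
import mlx.optimizers as optim

def cosineannealingwarmrestartfactor(initial_lr, restarts, decay_step, warmup_factor):
    # Hypothetical reconstruction: a short linear warmup followed by `restarts`
    # cosine-decay cycles of `decay_step` steps each, stitched together.
    warmup_steps = max(1, int(decay_step * warmup_factor))
    schedules = [optim.linear_schedule(initial_lr * 1e-2, initial_lr, warmup_steps)]
    boundaries = [warmup_steps]
    for _ in range(restarts):
        schedules.append(optim.cosine_decay(initial_lr, decay_step))
        boundaries.append(boundaries[-1] + decay_step)
    # join_schedules expects len(schedules) - 1 boundaries.
    return optim.join_schedules(schedules, boundaries[:-1])
```

The resulting callable maps the step count to a learning rate, so it can be passed directly as `learning_rate` to `optim.AdamW` as in the snippet above.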
This is great, thanks for the tips and improvements. The gather function (aka take_along_axis) was a mystery for me to set up, so this is great news.
Do you think the GRUCell can be sped up even more, or not?
Possibly... but I'm not sure it's the bottleneck in your model right now, so I wouldn't over-index on that. The best thing to do is figure out where the actual bottleneck is (sometimes this is not so easy) and focus on it.
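For example, since MLX evaluates lazily, a rough way to time individual pieces is to force evaluation around each block (a minimal sketch; `fn` and its arguments are placeholders and are assumed to be mx.array inputs):

```python
import time
import mlx.core as mx

def timed(label, fn, *args):
    mx.eval(*args)                 # flush any pending work before timing
    tic = time.perf_counter()
    out = fn(*args)
    mx.eval(out)                   # force the computation so the timing is real
    print(f"{label}: {(time.perf_counter() - tic) * 1e3:.2f} ms")
    return out
```

Wrapping, say, the GRUCell update, the neighbor gathers, and the readout separately with something like this should show which part dominates.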
I did not find a good tutorial on how to set up and properly use Xcode profiler tracing for the GPU / MPS, so I am a little blind on the bottleneck investigation.
On my machine (M3 Max, 128 GB), without your new "dropout speedup" commit, I now have:
this is based on ml-explore/mlx#1531
Now training is very fast, 0.285 sec per epoch (using the bernoulli commit), but I have a major issue that I have also seen in a lot of other trials: the training rarely gives good performance, while torch and keras are more stable and good. This is really a blocker for using MLX, as you need to train your model 10 to 20 times to get a good result, while torch and keras are consistently good (0.50-0.55), even if keras is slightly worse than pytorch.
Are you using lr warmup or any scheduler? In general, it is not guaranteed that the hyperparameters and/or initialization will transfer 1-1 from framework to framework, and you might want to play around with these a bit. If there is any reason to suspect a numerical inaccuracy, then feel free to open an issue with more information on that.
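If you want to experiment with the initialization side, one option (a sketch, not from the original code; it assumes the dense layers in `model` are `nn.Linear`) is to re-initialize the weights with `mlx.nn.init`:

```python
import mlx.core as mx
import mlx.nn as nn

init_fn = nn.init.glorot_uniform()

def reinit_linear(name, module):
    # Re-initialize only the Linear layers; leave everything else untouched.
    if isinstance(module, nn.Linear):
        module.weight = init_fn(module.weight)
        if "bias" in module:
            module.bias = mx.zeros_like(module.bias)

model.apply_to_modules(reinit_linear)
mx.eval(model.parameters())
```

Different initializers (e.g. `nn.init.he_uniform()` or `nn.init.uniform()`) can be swapped in the same way to see whether training stability improves.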
Indeed, I hope it is this one: #1542, as for me the difference is significant and cannot come from just the lr warmup scheduler.
Describe the bug
MLX is much slower than PyTorch & TensorFlow using the GPU: ~0.46 sec per epoch for PyTorch & TensorFlow versus 1.55 sec for MLX.
To Reproduce
Follow the instructions and code from https://github.com/thegodone/apple_ai_model
Expected behavior
MLX should be comparable in speed to TensorFlow and PyTorch, not significantly slower.
Desktop: