Add MPS GRUCell for efficiency #1508

Closed
wants to merge 1 commit into from

Conversation

HendricksJudy

Fixes #1500

Implement MPS-specific GRUCell and update GRU class for efficiency.

  • GRUCell Implementation:

    • Add a new GRUCell class in python/mlx/nn/layers/recurrent.py with MPS-specific optimizations.
    • Define the input shape and hidden state shape for the GRUCell.
    • Implement the forward pass for the GRUCell with MPS-specific optimizations.
  • GRU Class Update:

    • Update the GRU class in python/mlx/nn/layers/recurrent.py to use the new GRUCell for improved performance on MPS.
    • Define the input shape and hidden state shape for the GRU class.
    • Implement the forward pass for the GRU class using the GRUCell.
  • Documentation:

    • Update docs/src/python/nn/layers.rst to include the new GRUCell class.
  • Tests:

    • Add tests for the new GRUCell class in python/tests/test_nn.py to ensure correctness and performance improvements.
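For readers following the description above, here is a minimal, framework-free sketch of what a GRU cell step and a GRU layer built on top of it compute. It uses the standard PyTorch-style GRU equations and plain Python lists, so it says nothing about MPS or about the actual code in this PR. The function names (`gru_cell`, `gru_layer`) and the parameter keys (`Wxr`, `Whr`, `br`, `bhn`, and so on) are assumptions made for illustration only.

```python
import math


def sigmoid(v):
    # Numerically stable logistic function: sigmoid(v) = 0.5 * (1 + tanh(v / 2)).
    return 0.5 * (1.0 + math.tanh(0.5 * v))


def matvec(W, v):
    # W is a list of rows; returns the matrix-vector product as a list.
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def gru_cell(x, h, p):
    """One GRU step (hypothetical helper, not the PR's implementation).

    x: input vector of length input_size
    h: previous hidden state of length hidden_size
    p: dict of parameters; the key names are assumptions:
       Wxr, Wxz, Wxn  (hidden_size x input_size)
       Whr, Whz, Whn  (hidden_size x hidden_size)
       br, bz, bn, bhn (length hidden_size)
    """
    xr, xz, xn = matvec(p["Wxr"], x), matvec(p["Wxz"], x), matvec(p["Wxn"], x)
    hr, hz, hn = matvec(p["Whr"], h), matvec(p["Whz"], h), matvec(p["Whn"], h)
    new_h = []
    for i in range(len(h)):
        r = sigmoid(xr[i] + hr[i] + p["br"][i])  # reset gate
        z = sigmoid(xz[i] + hz[i] + p["bz"][i])  # update gate
        # Candidate state: the reset gate scales the hidden contribution.
        n = math.tanh(xn[i] + p["bn"][i] + r * (hn[i] + p["bhn"][i]))
        # Blend the candidate with the previous state via the update gate.
        new_h.append((1.0 - z) * n + z * h[i])
    return new_h


def gru_layer(xs, h0, p):
    """Apply gru_cell over a sequence; returns the hidden state after each step."""
    h, states = h0, []
    for x in xs:
        h = gru_cell(x, h, p)
        states.append(h)
    return states
```

In this sketch the layer is only a loop over the cell, so both perform the same arithmetic per time step; any speedup would have to come from how the step itself is executed, not from where the code lives.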

Love MLX : )

@thegodone

thegodone commented Oct 22, 2024

This is a cool update; it partially fixes #1500. Is it possible to make a 100% "GPU" version, fully optimized with MPS core functions, similar to the full CUDA version in TF/PyTorch?

@awni
Member

awni commented Oct 22, 2024

Maybe I'm missing something, but I don't follow why splitting the "cell" out from the GRU layer would be faster, or where "MPS" is involved. Could you please explain and/or share some benchmarks?

I don't think we need a GRUCell, but if your implementation is faster it should work just as well to keep it in the layer itself.

@thegodone

thegodone commented Oct 22, 2024

GRUCell is required by graph neural network models like AttentiveFP, which is one of the best for molecules. For reference, see pytorch/aten/src/ATen/native/RNN.cpp.

@HendricksJudy
Author

> This is a cool update; it partially fixes #1500. Is it possible to make a 100% "GPU" version, fully optimized with MPS core functions, similar to the full CUDA version in TF/PyTorch?

I apologize for my misunderstanding; I will try to make a full 100% "GPU" version optimized with MPS core functions.

@thegodone

Could you please document it, so that it can guide me in adding other functions?

@awni closed this Oct 22, 2024
Development

Successfully merging this pull request may close these issues.

[FEATURE] adding MPS GRUCell for efficiency