Add MPS GRUCell for efficiency #1508

Closed
wants to merge 1 commit into from

Commits on Oct 21, 2024

  1. Add MPS GRUCell for efficiency

    Fixes ml-explore#1500
    
    Implement an MPS-specific `GRUCell` and update the `GRU` class to use it for efficiency.
    
    * **GRUCell Implementation:**
      - Add a new `GRUCell` class in `python/mlx/nn/layers/recurrent.py` with MPS-specific optimizations.
      - Define the input shape and hidden state shape for the `GRUCell`.
      - Implement the forward pass for the `GRUCell` with MPS-specific optimizations.
    
    * **GRU Class Update:**
      - Update the `GRU` class in `python/mlx/nn/layers/recurrent.py` to use the new `GRUCell` for improved performance on MPS.
      - Define the input shape and hidden state shape for the `GRU` class.
      - Implement the forward pass for the `GRU` class using the `GRUCell`.
    
    * **Documentation:**
      - Update `docs/src/python/nn/layers.rst` to include the new `GRUCell` class.
    
    * **Tests:**
      - Add tests for the new `GRUCell` class in `python/tests/test_nn.py` to verify correctness and check for the expected performance improvement.
    
    Love MLX : )
    HendricksJudy committed Oct 21, 2024
    5233cfd
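
For readers unfamiliar with the cell/layer split the commit message describes, the following is a minimal sketch of a single GRU cell step and a layer that unrolls it over a sequence. It is not the code from this PR: it uses plain Python lists rather than MLX arrays, it contains no MPS-specific optimization, and the names `gru_cell_step`, `gru_unroll`, and the parameter keys `W_x`, `W_h`, `b`, `b_hn` are hypothetical. It assumes the common GRU formulation in which the reset gate scales the hidden-side contribution to the candidate state.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(W, x):
    # Multiply a matrix (list of rows) by a vector (list of floats).
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def gru_cell_step(x, h, params):
    """One GRU step (hypothetical helper, not the PR's implementation).

    x: input vector of length D
    h: previous hidden state of length H
    params: dict with "W_x" (3H x D), "W_h" (3H x H), "b" (3H), "b_hn" (H);
            rows are stacked in the order reset, update, candidate.
    """
    H = len(h)
    gx = [a + b for a, b in zip(matvec(params["W_x"], x), params["b"])]
    gh = matvec(params["W_h"], h)
    r = [sigmoid(gx[i] + gh[i]) for i in range(H)]          # reset gate
    z = [sigmoid(gx[H + i] + gh[H + i]) for i in range(H)]  # update gate
    n = [
        math.tanh(gx[2 * H + i] + r[i] * (gh[2 * H + i] + params["b_hn"][i]))
        for i in range(H)
    ]                                                        # candidate state
    # Blend the candidate with the previous hidden state.
    return [(1.0 - z[i]) * n[i] + z[i] * h[i] for i in range(H)]


def gru_unroll(xs, h0, params):
    """Apply the cell to each time step in turn and collect the hidden states."""
    h = h0
    outputs = []
    for x in xs:
        h = gru_cell_step(x, h, params)
        outputs.append(h)
    return outputs


def zero_params(input_size, hidden_size):
    # All-zero parameters, handy for sanity checks with a known answer.
    return {
        "W_x": [[0.0] * input_size for _ in range(3 * hidden_size)],
        "W_h": [[0.0] * hidden_size for _ in range(3 * hidden_size)],
        "b": [0.0] * (3 * hidden_size),
        "b_hn": [0.0] * hidden_size,
    }
```

With all-zero parameters both gates evaluate to 0.5 and the candidate state to 0, so each step halves the hidden state. That gives a simple correctness check of the kind the test bullet in the commit message calls for.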