
Support Apple MPS acceleration #1129

Open

ClaudiaComito wants to merge 150 commits into main
Conversation

@ClaudiaComito (Contributor) commented Mar 29, 2023

LAST EDITED DEC 12 2025
[Note from human: the most important changes in this PR are:

  • Apple MPS is now a valid GPU device in devices.py. Both ht.array(..., device="gpu") and ht.array(..., device="mps") are allowed (see the sketch after this list).
  • device attribute introduced for ht.random.permutation
  • added an item to the PR template checklist to test with MPS (manually for now, until the CI is expanded; see Expand CI to macos-m1 #1747)
  • Codecov is unhappy, but the introduced changes cannot be tested with our current setup.

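A minimal usage sketch of the new device handling (not part of the PR diff; it assumes a PyTorch build with MPS available, where float64 is unsupported):

import heat as ht

# Both spellings should resolve to the Metal backend on Apple hardware.
x = ht.array([1.0, 2.0, 3.0], device="mps")
y = ht.array([1.0, 2.0, 3.0], device="gpu")  # treated as MPS on Apple Silicon

# MPS has no float64, so floating-point data stays float32 there.
print(x.device, x.dtype)
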
Still to do:

  • update the README; it's probably best to do this just before the next release

Below is a Copilot summary.]

This pull request includes several changes to improve compatibility with Apple's Metal Performance Shaders (MPS) and correct some minor issues. The most important changes include modifications to handle unsupported data types on MPS, updates to unit tests, and minor corrections in documentation.


Reference

Issue/s resolved: #1053

Changes proposed:

Type of change

Memory requirements

Performance

Due Diligence

  • All split configurations tested (does not apply)
  • Multiple dtypes tested in relevant functions
  • Documentation updated (if needed)
  • Title of PR is suitable for corresponding CHANGELOG entry

Does this change modify the behaviour of other functions? If so, which?

no


@github-actions (Contributor)

Thank you for the PR!

@codecov

codecov bot commented Mar 29, 2023

Codecov Report

Attention: Patch coverage is 69.67213% with 37 lines in your changes missing coverage. Please review.

Project coverage is 91.99%. Comparing base (443afe3) to head (80a867e).

Files with missing lines                    Patch %   Missing lines
heat/core/arithmetics.py                    72.41%    8 ⚠️
heat/core/tests/test_suites/basic_test.py   52.94%    8 ⚠️
heat/core/devices.py                         0.00%    7 ⚠️
heat/core/_operations.py                    42.85%    4 ⚠️
heat/core/manipulations.py                  86.95%    3 ⚠️
heat/core/statistics.py                     87.50%    2 ⚠️
heat/core/dndarray.py                       75.00%    1 ⚠️
heat/core/linalg/basics.py                  83.33%    1 ⚠️
heat/core/relational.py                     75.00%    1 ⚠️
heat/core/signal.py                         66.66%    1 ⚠️
... and 1 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1129      +/-   ##
==========================================
- Coverage   92.26%   91.99%   -0.28%     
==========================================
  Files          84       84              
  Lines       12445    12535      +90     
==========================================
+ Hits        11482    11531      +49     
- Misses        963     1004      +41     
Flag   Coverage Δ
unit   91.99% <69.67%> (-0.28%) ⬇️

Flags with carried forward coverage won't be shown.


@ClaudiaComito ClaudiaComito changed the title Features/1053 support apple silicon gp us Support Apple MPS acceleration Mar 29, 2023
@ClaudiaComito ClaudiaComito added this to the 1.3.0 milestone Apr 17, 2023
@ClaudiaComito ClaudiaComito self-assigned this Apr 17, 2023
@ClaudiaComito (Contributor, Author)
"Tests failed; I reran them to check whether that's just a HW problem."

Thanks, sadly it looks like an actual problem, conveniently without an error message. I'll debug it; I'll probably get to it next week.

@JuanPedroGHM JuanPedroGHM self-requested a review December 16, 2024 08:58
@ClaudiaComito ClaudiaComito requested review from JuanPedroGHM and removed request for JuanPedroGHM December 16, 2024 08:58
Comment on lines +87 to +91
if self.is_mps:
dtypes = [ht.float32]
else:
dtypes = [ht.float32, ht.float64]

Member:

This (and all subsequent tests that have to filter by system) would be a great target for parametrization (now that we've talked about introducing hypothesis and parametrized tests).

A good example of how to skip certain parameters based on the OS is here.
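For illustration, a sketch of what such a parametrization could look like (the test body and the platform check are assumptions, not code from this PR):

import sys

import pytest
import heat as ht

# Hypothetical example: express the dtype filter as parameters and skip
# float64 where MPS is in use, instead of branching inside the test body.
@pytest.mark.parametrize(
    "dtype",
    [
        ht.float32,
        pytest.param(
            ht.float64,
            marks=pytest.mark.skipif(
                sys.platform == "darwin", reason="float64 is not supported on MPS"
            ),
        ),
    ],
)
def test_example(dtype):
    x = ht.ones((4, 4), dtype=dtype)
    assert x.dtype == dtype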

@@ -339,7 +339,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
else: # A not split, b.split == -2
b_lshapes_cum = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=tdev),
torch.zeros(1, dtype=torch.int64, device=tdev),
Member:

Is there a reason for this change? Why not use the default dtype?

@@ -2154,19 +2168,20 @@ def test_triu(self):
self.assertTrue(result.larray[0, -1] == 1)

def test_vdot(self):
a = ht.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]], split=0)
b = ht.array([[1 + 2j, 3 + 4j], [5 + 6j, 7 + 8j]], split=0)
if not self.is_mps:
Member:

Test should be skipped using unittest.skipIf or pytest.mark.skipif
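For example, a hypothetical guard along these lines (detecting MPS via torch.backends.mps.is_available() is an assumption about how the suite would check for it):

import unittest

import torch

class TestVdot(unittest.TestCase):
    # Skip the whole test on MPS instead of branching on self.is_mps inside it.
    @unittest.skipIf(
        torch.backends.mps.is_available(), "complex dtypes are not supported on MPS"
    )
    def test_vdot(self):
        ...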

ht.allclose(q.transpose([0, 1, 3, 2]) @ q, batched_id, atol=1e-6, rtol=1e-6)
)
self.assertTrue(ht.allclose(q @ r, x, atol=1e-6, rtol=1e-6))
# skip float64 tests on MPS
Member:

Test should be skipped using unittest.skipIf or pytest.mark.skipif

]
rtols = [1e-1, 1e-2, 1e-3]
ranks = [5, 10, 15]
# not testing on MPS for now as torch.norm() is unstable
Member:

Test should be skipped using unittest.skipIf or pytest.mark.skipif

Comment on lines +169 to +171
is_mps = x.larray.is_mps or y.larray.is_mps
if is_mps and result_type is types.float64:
result_type = types.float32
Member:

Instead of checking every time after calling types.result_type, the check could be done inside types.result_type() itself. That would save a lot of extra if statements and reduce the chance of forgetting the check at a call site.
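A rough sketch of the idea (_promote is a placeholder for Heat's existing promotion logic, not an actual function):

# hypothetical shape of result_type in heat/core/types.py after the change
def result_type(*operands):
    promoted = _promote(*operands)  # placeholder for the existing promotion logic
    # Downcast once, centrally: float64 is not available on MPS.
    if promoted is float64 and any(
        getattr(op, "larray", None) is not None and op.larray.is_mps
        for op in operands
    ):
        promoted = float32
    return promoted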

Comment on lines +119 to +122
if a.larray.is_mps and promoted_type == float64:
# cannot cast to float64 on MPS
promoted_type = float32

Member:

Same with promote_types.

Development

Successfully merging this pull request may close these issues.

Support Apple's MPS backend
3 participants