Support Apple MPS acceleration #1129

Open
wants to merge 150 commits into base: main
Changes from all commits (150 commits)
4e9ed3a
Support Apple MPS, first pass
ClaudiaComito Nov 26, 2022
50297d7
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Mar 18, 2023
febdcfc
Include torch 2.0 in device check
ClaudiaComito Mar 18, 2023
a5642e9
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Mar 29, 2023
8445476
reinstate quick_start.md
ClaudiaComito Mar 29, 2023
8763146
Merge branch 'docs/reinstate-quick-start' into features/1053-support-…
ClaudiaComito Mar 29, 2023
92306e1
[skip ci] edits
ClaudiaComito Mar 29, 2023
b0d7f0f
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Apr 17, 2023
ffa014e
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Apr 18, 2023
6015ebf
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Apr 24, 2023
f3a5ad8
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Apr 27, 2023
a96441b
fix tolerance for torch 2
ClaudiaComito Apr 27, 2023
c7b70c6
implement __array__ method
ClaudiaComito May 18, 2023
ff8af94
test __array__ method
ClaudiaComito May 18, 2023
fc059c2
test __array__ method
ClaudiaComito May 18, 2023
ebf4c51
Merge branch 'features/1153_array_method' into features/1053-support-…
ClaudiaComito May 18, 2023
5c471e2
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito May 22, 2023
34c9347
Merge branch 'features/1053-support-Apple-silicon-GPUs' of github.com…
ClaudiaComito May 22, 2023
f4e3afd
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jun 16, 2023
681086f
Merge branch 'features/#1117-array-copy-None' into features/1053-supp…
ClaudiaComito Jun 19, 2023
3ad2718
dtype changes for MPS device
ClaudiaComito Jun 19, 2023
b50d9f6
accomodate single prec dtypes for MPS
ClaudiaComito Jun 20, 2023
8ff410a
torch.linalg.inv workaround on MPS
ClaudiaComito Jun 20, 2023
e3bd1d2
Implement MPS-friendly choice of dtypes
ClaudiaComito Jun 20, 2023
7564bac
skip hSVD tests on MPS, torch.norm unstable
ClaudiaComito Jun 20, 2023
cd49de9
cast operands to single precision on MPS
ClaudiaComito Jun 20, 2023
3559295
do not cast to double precision on MPS
ClaudiaComito Jun 20, 2023
cec0e75
skip complex dtype tests on MPS
ClaudiaComito Jun 20, 2023
2361713
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jun 20, 2023
3fb3667
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Aug 13, 2023
7fbbc45
skip MPI op tests on Apple MPS
ClaudiaComito Aug 16, 2023
6f00be8
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Oct 18, 2023
859b844
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Nov 27, 2023
9414360
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Apr 13, 2024
4cc7401
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito May 28, 2024
e7a8c39
test on float32 only on MPS
ClaudiaComito May 28, 2024
6187804
skip test on MPS, ComplexFloat not supported
ClaudiaComito May 28, 2024
0d464a3
test float32 only on MPS
ClaudiaComito May 28, 2024
b858a89
test float32 only on MPS
ClaudiaComito May 28, 2024
c887a91
edit docs
ClaudiaComito May 28, 2024
a58be6c
non-distr in-place cumprod
ClaudiaComito May 28, 2024
2b71ed2
no float64 on MPS
ClaudiaComito May 28, 2024
83d0136
skip float64 tests on MPS
ClaudiaComito May 28, 2024
a72678d
implement non-distr in-place cumsum
ClaudiaComito May 28, 2024
30cfa8a
catch MPS hypot crash on int64 early
ClaudiaComito May 28, 2024
9e128d6
do not test non-supported dtypes on MPS
ClaudiaComito May 28, 2024
95dcd0e
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito May 28, 2024
cdef743
replace device check with getattr to accomodate scalars
ClaudiaComito May 28, 2024
e6db9ae
skip complex-math tests on MPS and MacOS<14
ClaudiaComito May 29, 2024
00f6791
skip float64 tests on MPS
ClaudiaComito May 29, 2024
6c532f5
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jun 4, 2024
8e7e2a6
do not test float64 on MPS
ClaudiaComito Jun 4, 2024
6202f3b
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jun 4, 2024
113e562
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jun 6, 2024
bf3bbb9
do not test float64 on MPS
ClaudiaComito Jun 6, 2024
47b3253
skip float64 tests on MPS
ClaudiaComito Jun 7, 2024
3e8c8aa
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jun 12, 2024
ad165fa
do not cast to float64 on MPS
ClaudiaComito Jun 12, 2024
d71c9c1
do not test float64 and complex on mps
ClaudiaComito Jun 19, 2024
84e4942
early exit for non-distr flip
ClaudiaComito Jun 19, 2024
1b1d18d
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jul 8, 2024
2587b89
int32 cumsum on MPS
ClaudiaComito Jul 8, 2024
1db5cd6
skip reduce ops on 4+-dim arrays on MPS
ClaudiaComito Jul 8, 2024
7859f27
skip float64 tests on MPS
ClaudiaComito Jul 8, 2024
92fcafe
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jul 10, 2024
391d3e8
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jul 12, 2024
34192f9
skip int64 cumsum test on MPS
ClaudiaComito Jul 12, 2024
0085388
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Jul 17, 2024
f596b73
indexing tests on MPS
ClaudiaComito Jul 17, 2024
015ba20
skip partitioned tests on MPS
ClaudiaComito Jul 18, 2024
4dda4ea
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Sep 9, 2024
0df738a
add is_mps property
ClaudiaComito Sep 10, 2024
e096d33
refactor roll
ClaudiaComito Sep 10, 2024
32d20f2
refactor tests
ClaudiaComito Sep 10, 2024
52b1f14
refactor tests
ClaudiaComito Sep 10, 2024
88740b9
remove unnecessary try/except
ClaudiaComito Sep 12, 2024
79a7e44
allow double precision int input
ClaudiaComito Sep 12, 2024
ead4ac3
simplify MPS heuristics
ClaudiaComito Sep 12, 2024
9a71d7e
update tests
ClaudiaComito Sep 12, 2024
a3a16c4
cum ops now supported on int64
ClaudiaComito Sep 12, 2024
493c877
update tests
ClaudiaComito Sep 12, 2024
8ca15dd
remove ref to torch < 2
ClaudiaComito Sep 12, 2024
26c003d
allow cum ops on int64
ClaudiaComito Sep 12, 2024
cd23eaf
update tests
ClaudiaComito Sep 12, 2024
784d5a1
update tests
ClaudiaComito Sep 12, 2024
2330492
cum ops now supported
ClaudiaComito Sep 13, 2024
fc9b03a
update tests
ClaudiaComito Sep 13, 2024
06b2af7
update tests
ClaudiaComito Sep 13, 2024
c81300a
update tests
ClaudiaComito Sep 30, 2024
7b1fe5b
update test_svd
ClaudiaComito Sep 30, 2024
fb75a2e
update test_svdtools
ClaudiaComito Sep 30, 2024
da63ce8
simplify is_mps
ClaudiaComito Sep 30, 2024
35cea1c
WIP - update test_statistics
ClaudiaComito Sep 30, 2024
b23c6b7
bypass allreduce call for non-distr histc
ClaudiaComito Oct 1, 2024
ae3a96d
adapt test_statistics
ClaudiaComito Oct 1, 2024
d998e25
percentile output dtype on MPS
ClaudiaComito Oct 1, 2024
1a7ac24
adapt test_manipulations to MPS
ClaudiaComito Oct 1, 2024
bffd3db
early out in non-distr cases
ClaudiaComito Oct 1, 2024
22d6e0d
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Oct 1, 2024
5628ea4
update test_complex_math
ClaudiaComito Oct 2, 2024
727cadd
update test_exp
ClaudiaComito Oct 2, 2024
aaff882
update test_factories
ClaudiaComito Oct 2, 2024
8e44f89
update test_io
ClaudiaComito Oct 2, 2024
386cd68
update test_logical
ClaudiaComito Oct 2, 2024
8df2649
skip tests on MPS
ClaudiaComito Oct 2, 2024
550e638
update cumops for MPS
ClaudiaComito Oct 2, 2024
00c6aab
skip print_GPU on MPS
ClaudiaComito Oct 2, 2024
c31f584
support device setting for randperm
ClaudiaComito Oct 9, 2024
6aec0b4
expand randperm docs on device
ClaudiaComito Oct 10, 2024
fcf7a9e
adapt tests to MPS
ClaudiaComito Oct 10, 2024
7ca9548
adapt test_rounding to MPS
ClaudiaComito Oct 10, 2024
3c94c88
do not cast to float64 on MPS
ClaudiaComito Oct 10, 2024
9d49269
adapt test_signal to MPS
ClaudiaComito Oct 10, 2024
827e231
adatpt basic_tests to MPS
ClaudiaComito Oct 11, 2024
222e8f0
adapt test_trigonometrics to MPS
ClaudiaComito Oct 14, 2024
e9d4eff
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Nov 11, 2024
61daa63
adapt TestRSVD to MPS
ClaudiaComito Nov 11, 2024
f7db54e
adapt test_randint to MPS
ClaudiaComito Nov 11, 2024
2e0b686
do not cast to float64 on MPS
ClaudiaComito Nov 11, 2024
21de13e
adapt test_types to MPS
ClaudiaComito Nov 11, 2024
509cb3b
skip MPS FFT tests on MacOS < 14
ClaudiaComito Nov 26, 2024
2bd212a
skip Threefry tests on MPS
ClaudiaComito Nov 26, 2024
e9db00e
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Nov 26, 2024
10ec173
adatpt test_fft to MPS
ClaudiaComito Nov 29, 2024
2993c11
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Nov 29, 2024
a20fab9
skip float64 test_isvd on MPS
ClaudiaComito Nov 29, 2024
b3b72c5
adapt test_pca to MPS
ClaudiaComito Nov 29, 2024
718b47e
no sparse tests on MPS
ClaudiaComito Nov 29, 2024
6177f94
adapt to MPS
ClaudiaComito Nov 29, 2024
667f436
adapt to MPS
ClaudiaComito Nov 29, 2024
11d87b1
remove print statement
ClaudiaComito Nov 29, 2024
06db453
skip float64 tests on MPS
ClaudiaComito Nov 29, 2024
a2cfa29
skip complex128 on MPS
ClaudiaComito Nov 29, 2024
0875b5c
fix CPU tests
ClaudiaComito Dec 2, 2024
cd1b2b4
fix CPU tests
ClaudiaComito Dec 2, 2024
5147c7b
update test_solver
ClaudiaComito Dec 3, 2024
0589f93
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Dec 3, 2024
0528481
skip float64 batch-qr test on MPS
ClaudiaComito Dec 3, 2024
0efad5e
adapt test_permutation
ClaudiaComito Dec 10, 2024
0a6b6ed
Merge branch 'main' into features/1053-support-Apple-silicon-GPUs
ClaudiaComito Dec 10, 2024
164b7b7
return inv as DNDarray, not Tensor
ClaudiaComito Dec 10, 2024
c7cdf34
skip DMD tests on MPS
ClaudiaComito Dec 11, 2024
dc2ab84
increase allclose tolerance for test_inv
ClaudiaComito Dec 11, 2024
6907c71
skip line formatting
ClaudiaComito Dec 11, 2024
58cc44f
skip line formatting
ClaudiaComito Dec 11, 2024
1e902d1
skip line formatting
ClaudiaComito Dec 11, 2024
71e83c5
debugging test_sort on AMD
ClaudiaComito Dec 12, 2024
fa3e900
update test_iris
ClaudiaComito Dec 12, 2024
103dc7e
indices sorting workaround for CUDA
ClaudiaComito Dec 12, 2024
80a867e
update PR template, docs
ClaudiaComito Dec 12, 2024
1 change: 1 addition & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -6,6 +6,7 @@
- Implementation:
- [ ] unit tests: all split configurations tested
- [ ] unit tests: multiple dtypes tested
- [ ] **NEW** unit tests: MPS tested (1 MPI process, 1 GPU)
- [ ] benchmarks: created for new functionality
- [ ] benchmarks: performance improved or maintained
- [ ] documentation updated where needed
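The new checklist item assumes a single MPI process on a single Apple GPU. A quick way to confirm that the local PyTorch build can actually use MPS before running the tests (a hedged sketch; the test-runner invocation is an assumption, not something prescribed by this PR):

import torch

# Both checks must be True for Heat to pick the MPS-backed "gpu" device (see devices.py below)
print("MPS built:    ", torch.backends.mps.is_built())
print("MPS available:", torch.backends.mps.is_available())
# then run the suite on one process, e.g.: mpirun -n 1 pytest heat/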
1 change: 0 additions & 1 deletion heat/classification/kneighborsclassifier.py
@@ -122,7 +122,6 @@ def predict(self, x: DNDarray) -> DNDarray:
"""
distances = self.effective_metric_(x, self.x)
_, indices = ht.topk(distances, self.n_neighbors, largest=False)

predictions = self.y[indices.flatten()]
predictions.balance_()
predictions = ht.reshape(predictions, (indices.gshape + (self.y.gshape[1],)))
7 changes: 6 additions & 1 deletion heat/cluster/tests/test_batchparallelclustering.py
@@ -84,14 +84,19 @@ def test_get_and_set_params(self):
self.assertEqual(10, parallelclusterer.n_clusters)

def test_spherical_clusters(self):
if self.is_mps:
dtypes = [ht.float32]
else:
dtypes = [ht.float32, ht.float64]

Review comment (Member) on lines +87 to +91:
This (and all subsequent tests that have to filter by system) would be a great target for parametrization (now that we talked about introducing hypothesis and parametrized tests).

A good example of how to skip certain parameters based on the OS can be found here.
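A minimal sketch of what that parametrization could look like, assuming pytest-style tests and a module-level MPS flag (the names here are hypothetical; the PR itself exposes a similar is_mps attribute on the shared TestCase):

import pytest
import torch
import heat as ht

# hypothetical module-level flag mirroring TestCase.is_mps
ON_MPS = torch.backends.mps.is_built() and torch.backends.mps.is_available()

@pytest.mark.parametrize(
    "dtype",
    [
        ht.float32,
        pytest.param(
            ht.float64,
            marks=pytest.mark.skipif(ON_MPS, reason="MPS does not support float64"),
        ),
    ],
)
def test_spherical_clusters(dtype):
    ...  # build the dataset with `dtype` and run the clusterer as above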

for ParallelClusterer in [ht.cluster.BatchParallelKMeans, ht.cluster.BatchParallelKMedians]:
if ParallelClusterer is ht.cluster.BatchParallelKMeans:
ppinitkws = ["k-means++"]
elif ParallelClusterer is ht.cluster.BatchParallelKMedians:
ppinitkws = ["k-medians++"]
for seed in [1, None]:
n = 20 * ht.MPI_WORLD.size
for dtype in [ht.float32, ht.float64]:
for dtype in dtypes:
data = create_spherical_dataset(
num_samples_cluster=n,
radius=1.0,
9 changes: 7 additions & 2 deletions heat/cluster/tests/test_kmeans.py
@@ -100,15 +100,20 @@ def test_spherical_clusters(self):

# different datatype
n = 20 * ht.MPI_WORLD.size
if self.is_mps:
# MPS does not support float64
dtype = ht.float32
else:
dtype = ht.float64
data = create_spherical_dataset(
num_samples_cluster=n, radius=1.0, offset=4.0, dtype=ht.float64, random_state=seed
num_samples_cluster=n, radius=1.0, offset=4.0, dtype=dtype, random_state=seed
)
kmeans = ht.cluster.KMeans(n_clusters=4, init="kmeans++")
kmeans.fit(data)
self.assertIsInstance(kmeans.cluster_centers_, ht.DNDarray)
self.assertEqual(kmeans.cluster_centers_.shape, (4, 3))

# on Ints (different radius, offset and datatype
# on Ints (different radius, offset and datatype)
data = create_spherical_dataset(
num_samples_cluster=n, radius=10.0, offset=40.0, dtype=ht.int32, random_state=seed
)
7 changes: 6 additions & 1 deletion heat/cluster/tests/test_kmedians.py
@@ -100,8 +100,13 @@ def test_spherical_clusters(self):

# different datatype
n = 20 * ht.MPI_WORLD.size
# MPS does not support float64
if self.is_mps:
dtype = ht.float32
else:
dtype = ht.float64
data = create_spherical_dataset(
num_samples_cluster=n, radius=1.0, offset=4.0, dtype=ht.float64, random_state=seed
num_samples_cluster=n, radius=1.0, offset=4.0, dtype=dtype, random_state=seed
)
kmedians = ht.cluster.KMedians(n_clusters=4, init="kmedians++")
kmedians.fit(data)
7 changes: 6 additions & 1 deletion heat/cluster/tests/test_kmedoids.py
@@ -103,8 +103,13 @@ def test_spherical_clusters(self):

# different datatype
n = 20 * ht.MPI_WORLD.size
# MPS does not support float64
if self.is_mps:
dtype = ht.float32
else:
dtype = ht.float64
data = create_spherical_dataset(
num_samples_cluster=n, radius=1.0, offset=4.0, dtype=ht.float64, random_state=seed
num_samples_cluster=n, radius=1.0, offset=4.0, dtype=dtype, random_state=seed
)
kmedoid = ht.cluster.KMedoids(n_clusters=4, init="kmedoids++")
kmedoid.fit(data)
85 changes: 44 additions & 41 deletions heat/cluster/tests/test_spectral.py
@@ -2,6 +2,7 @@
import unittest

import heat as ht
import torch

from ...core.tests.test_suites.basic_test import TestCase

@@ -35,49 +36,51 @@ def test_get_and_set_params(self):
self.assertEqual(10, spectral.n_clusters)

def test_fit_iris(self):
# get some test data
iris = ht.load("heat/datasets/iris.csv", sep=";", split=0)
m = 10
# fit the clusters
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, metric="rbf", laplacian="fully_connected", n_lanczos=m
)
spectral.fit(iris)
self.assertIsInstance(spectral.labels_, ht.DNDarray)
# skip on MPS, matmul on ComplexFloat not supported as of PyTorch 2.5
if not self.is_mps:
# get some test data
iris = ht.load("heat/datasets/iris.csv", sep=";", split=0)
m = 10
# fit the clusters
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, metric="rbf", laplacian="fully_connected", n_lanczos=m
)
spectral.fit(iris)
self.assertIsInstance(spectral.labels_, ht.DNDarray)

spectral = ht.cluster.Spectral(
metric="euclidean",
laplacian="eNeighbour",
threshold=0.5,
boundary="upper",
n_lanczos=m,
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)
spectral = ht.cluster.Spectral(
metric="euclidean",
laplacian="eNeighbour",
threshold=0.5,
boundary="upper",
n_lanczos=m,
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)

spectral = ht.cluster.Spectral(
gamma=0.1,
metric="rbf",
laplacian="eNeighbour",
threshold=0.5,
boundary="upper",
n_lanczos=m,
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)
spectral = ht.cluster.Spectral(
gamma=0.1,
metric="rbf",
laplacian="eNeighbour",
threshold=0.5,
boundary="upper",
n_lanczos=m,
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)

kmeans = {"kmeans++": "kmeans++", "max_iter": 30, "tol": -1}
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, normalize=True, n_lanczos=m, params=kmeans
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)
kmeans = {"kmeans++": "kmeans++", "max_iter": 30, "tol": -1}
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, normalize=True, n_lanczos=m, params=kmeans
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)

# Errors
with self.assertRaises(NotImplementedError):
spectral = ht.cluster.Spectral(metric="ahalanobis", n_lanczos=m)
# Errors
with self.assertRaises(NotImplementedError):
spectral = ht.cluster.Spectral(metric="ahalanobis", n_lanczos=m)

iris_split = ht.load("heat/datasets/iris.csv", sep=";", split=1)
spectral = ht.cluster.Spectral(n_lanczos=20)
with self.assertRaises(NotImplementedError):
spectral.fit(iris_split)
iris_split = ht.load("heat/datasets/iris.csv", sep=";", split=1)
spectral = ht.cluster.Spectral(n_lanczos=20)
with self.assertRaises(NotImplementedError):
spectral.fit(iris_split)
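Wrapping the whole test body in `if not self.is_mps:` keeps the test green on MPS but reports it as passed rather than skipped. A hedged alternative sketch using unittest's skip machinery, assuming the is_mps attribute is available on the test instance as in the other test files of this PR:

import heat as ht
from heat.core.tests.test_suites.basic_test import TestCase

class TestSpectralOnMPS(TestCase):
    def test_fit_iris(self):
        if self.is_mps:
            # matmul on ComplexFloat is not supported on MPS as of PyTorch 2.5
            self.skipTest("ComplexFloat matmul not supported on MPS")
        iris = ht.load("heat/datasets/iris.csv", sep=";", split=0)
        spectral = ht.cluster.Spectral(
            n_clusters=3, gamma=1.0, metric="rbf", laplacian="fully_connected", n_lanczos=10
        )
        spectral.fit(iris)
        self.assertIsInstance(spectral.labels_, ht.DNDarray)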
9 changes: 9 additions & 0 deletions heat/core/_operations.py
@@ -197,6 +197,10 @@ def __get_out_params(target, other=None, map=None):
sanitation.sanitize_out(out, output_shape, output_split, output_device, output_comm)
t1, t2 = sanitation.sanitize_distribution(t1, t2, target=out)

# MPS does not support float64
if t1.larray.is_mps and promoted_type == torch.float64:
promoted_type = torch.float32

result = operation(t1.larray.to(promoted_type), t2.larray.to(promoted_type), **fn_kwargs)

if out is None and where is True:
@@ -282,6 +286,9 @@

if dtype is not None:
dtype = types.canonical_heat_type(dtype)
if x.larray.is_mps and dtype == types.float64:
warnings.warn("MPS does not support float64, will cast to float32")
dtype = types.float32

if out is not None:
sanitation.sanitize_out(out, x.shape, x.split, x.device)
@@ -369,6 +376,8 @@
# we need floating point numbers here, due to PyTorch only providing sqrt() implementation for float32/64
if not no_cast:
promoted_type = types.promote_types(x.dtype, types.float32)
if promoted_type is types.float64 and x.device.torch_device.startswith("mps"):
promoted_type = types.float32
torch_type = promoted_type.torch_type()
else:
torch_type = x.larray.dtype
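The float64-to-float32 demotion above appears in __binary_op, __cum_op and __local_op. A stand-alone sketch of the same pattern as a hypothetical helper (not part of this PR), useful for reasoning about what the promoted dtype ends up being on MPS:

import torch

def demote_for_mps(promoted_type: torch.dtype, tensor: torch.Tensor) -> torch.dtype:
    # MPS has no double-precision support; fall back to single precision
    if tensor.is_mps:
        if promoted_type == torch.float64:
            return torch.float32
        if promoted_type == torch.complex128:
            return torch.complex64
    return promoted_type

# e.g.: promoted = demote_for_mps(torch.result_type(a, b), a)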
53 changes: 45 additions & 8 deletions heat/core/arithmetics.py
@@ -821,6 +821,14 @@ def wrap_cumprod_(a: torch.Tensor, b: int, out=None, dtype=None) -> torch.Tensor
def wrap_mul_(a: torch.Tensor, b: torch.Tensor, out=None) -> torch.Tensor:
return a.mul_(b)

axis = stride_tricks.sanitize_axis(t.shape, axis)
if axis is None:
raise NotImplementedError("cumprod_ is not implemented for axis=None")

if not t.is_distributed():
t.larray.cumprod_(dim=axis)
return t

return _operations.__cum_op(t, wrap_cumprod_, MPI.PROD, wrap_mul_, 1, axis, dtype=None, out=t)


@@ -891,6 +899,14 @@ def wrap_cumsum_(a: torch.Tensor, b: int, out=None, dtype=None) -> torch.Tensor:
def wrap_add_(a: torch.Tensor, b: torch.Tensor, out=None) -> torch.Tensor:
return a.add_(b)

axis = stride_tricks.sanitize_axis(t.shape, axis)
if axis is None:
raise NotImplementedError("cumsum_ is not implemented for axis=None")

if not t.is_distributed():
t.larray.cumsum_(dim=axis)
return t

return _operations.__cum_op(t, wrap_cumsum_, MPI.SUM, wrap_add_, 0, axis, dtype=None, out=t)
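A short usage sketch of the new non-distributed fast path (assuming the in-place cumsum_ is exposed as a DNDarray method, as hypot_ is further down in this file):

import heat as ht

t = ht.arange(6, dtype=ht.float32).reshape((2, 3))  # split=None, so not distributed
t.cumsum_(1)  # in-place cumulative sum along axis 1; takes the local torch fast path above
print(t)      # DNDarray([[0., 1., 3.], [3., 7., 12.]], ...)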


@@ -1622,8 +1638,8 @@ def wrap_gcd_(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:


def hypot(
a: DNDarray,
b: DNDarray,
t1: DNDarray,
t2: DNDarray,
/,
out: Optional[DNDarray] = None,
*,
@@ -1635,9 +1651,9 @@

Parameters
----------
a: DNDarray
t1: DNDarray
The first input array
b: DNDarray
t2: DNDarray
the second input array
out: DNDarray, optional
The output array. It must have a shape that the inputs broadcast to and matching split axis.
@@ -1656,12 +1672,22 @@
>>> ht.hypot(a,b)
DNDarray([2.2361, 3.6056, 3.6056], dtype=ht.float32, device=cpu:0, split=None)
"""
# catch int64 operation crash on MPS. TODO: issue still persists in 2.3.0, check 2.4, report to PyTorch
t1_ismps = getattr(getattr(t1, "device", "cpu"), "torch_device", "cpu").startswith("mps")
t2_ismps = getattr(getattr(t2, "device", "cpu"), "torch_device", "cpu").startswith("mps")
if t1_ismps or t2_ismps:
t1_isint64 = getattr(t1, "dtype", None) == types.int64
t2_isint64 = getattr(t2, "dtype", None) == types.int64
if t1_isint64 or t2_isint64:
raise TypeError(
f"hypot on MPS does not support int64 dtype, got {t1.dtype}, {t2.dtype}"
)

try:
res = _operations.__binary_op(torch.hypot, a, b, out, where)
res = _operations.__binary_op(torch.hypot, t1, t2, out, where)
except RuntimeError:
# every other possibility is caught by __binary_op
raise TypeError(f"Not implemented for array dtype, got {a.dtype}, {b.dtype}")

raise TypeError(f"hypot on CPU does not support Int dtype, got {t1.dtype}, {t2.dtype}")
return res
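A hedged illustration of the new guard on an Apple-silicon machine (the "gpu" device string maps to mps per the devices.py changes below):

import heat as ht

a = ht.array([1, 2, 3], dtype=ht.int64, device="gpu")  # backed by mps:0 on Apple silicon
b = ht.array([4, 5, 6], dtype=ht.int64, device="gpu")
try:
    ht.hypot(a, b)
except TypeError as err:
    print(err)  # "hypot on MPS does not support int64 dtype, ..."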


@@ -1704,14 +1730,25 @@ def hypot_(t1: DNDarray, t2: DNDarray) -> DNDarray:
def wrap_hypot_(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a.hypot_(b)

# catch int64 operation crash on MPS
t1_ismps = getattr(getattr(t1, "device", "cpu"), "torch_device", "cpu").startswith("mps")
t2_ismps = getattr(getattr(t2, "device", "cpu"), "torch_device", "cpu").startswith("mps")
if t1_ismps or t2_ismps:
t1_isint64 = getattr(t1, "dtype", None) == types.int64
t2_isint64 = getattr(t2, "dtype", None) == types.int64
if t1_isint64 or t2_isint64:
raise TypeError(
f"hypot_ on MPS does not support int64 dtype, got {t1.dtype}, {t2.dtype}"
)

try:
return _operations.__binary_op(wrap_hypot_, t1, t2, out=t1)
except NotImplementedError:
raise ValueError(
f"In-place operation not allowed: operands are distributed along different axes. \n Operand 1 with shape {t1.shape} is split along axis {t1.split}. \n Operand 2 with shape {t2.shape} is split along axis {t2.split}."
)
except RuntimeError:
raise TypeError(f"Not implemented for array dtype, got {t1.dtype}, {t2.dtype}")
raise TypeError(f"hypot on CPU does not support Int dtype, got {t1.dtype}, {t2.dtype}")


DNDarray.hypot_ = hypot_
28 changes: 26 additions & 2 deletions heat/core/devices.py
@@ -16,13 +16,13 @@

class Device:
"""
Implements a compute device. HeAT can run computations on different compute devices or backends.
Implements a compute device. Heat can run computations on different compute devices or backends.
A device describes the device type and id on which said computation should be carried out.

Parameters
----------
device_type : str
Represents HeAT's device name
Represents Heat's device name
device_id : int
The device id
torch_device : str
@@ -34,6 +34,8 @@ class Device:
device(cpu:0)
>>> ht.Device("gpu", 0, "cuda:0")
device(gpu:0)
>>> ht.Device("gpu", 0, "mps:0") # on Apple M1/M2
device(gpu:0)
"""

def __init__(self, device_type: str, device_id: int, torch_device: str):
@@ -133,6 +135,28 @@ def __eq__(self, other: Any) -> bool:
# the GPU device should be exported as global symbol
__all__.append("gpu")

elif torch.backends.mps.is_built() and torch.backends.mps.is_available():
# Apple MPS available
gpu_id = 0
# create a new GPU device
gpu = Device("gpu", gpu_id, "mps:{}".format(gpu_id))
"""
The standard GPU Device on Apple M1/M2

Examples
--------
>>> ht.cpu
device(cpu:0)
>>> ht.ones((2, 3), device=ht.gpu)
DNDarray([[1., 1., 1.],
[1., 1., 1.]], dtype=ht.float32, device=mps:0, split=None)
"""
# add a GPU device string
__device_mapping[gpu.device_type] = gpu
__device_mapping["mps"] = gpu
# the GPU device should be exported as global symbol
__all__.append("gpu")


def get_device() -> Device:
"""
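For downstream code that needs to branch on the backend, a hedged sketch of an MPS check built on the mapping above (it mirrors the is_mps checks used throughout this PR's tests):

import heat as ht

def default_device_is_mps() -> bool:
    # Device.torch_device is a string such as "mps:0" once the mapping above has run
    return ht.get_device().torch_device.startswith("mps")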