
Commit bdd8f51

razarmehr authored and pytorchmergebot committed
[MPS] Add Python Module Bindings for the MPS backend (pytorch#94417)
- This PR is a prerequisite for the upcoming Memory Leak Detection PR.
- Enable global manual seeding via `torch.manual_seed()` + test case
- Add `torch.mps.synchronize()` to wait for MPS stream to finish + test case
- Enable the following python interfaces for MPS: `torch.mps.[get_rng_state(), set_rng_state(), synchronize(), manual_seed(), seed()]`
- Added some test cases in test_mps.py
- Added `mps.rst` to document the `torch.mps` module.
- Fixed the failure with `test_public_bindings.py`

Description of new files added:
- `torch/csrc/mps/Module.cpp`: implements `torch._C` module functions for `torch.mps` and `torch.backends.mps`.
- `torch/mps/__init__.py`: implements Python bindings for `torch.mps` module.

Pull Request resolved: pytorch#94417
Approved by: https://github.com/albanD
1 parent a0f9abd commit bdd8f51

15 files changed: +262 −16 lines
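The commit message above lists the new user-facing `torch.mps` functions. As a quick orientation, here is a minimal usage sketch of the seeding and synchronization APIs added by this PR (it assumes a macOS machine where the MPS backend is available):

import torch

# Seeding: the global torch.manual_seed() now also seeds the default MPS
# generator, and torch.mps.manual_seed() seeds it directly.
torch.manual_seed(230)
x = torch.randn(5, device="mps")

torch.mps.manual_seed(230)
y = torch.randn(5, device="mps")
assert torch.equal(x.cpu(), y.cpu())  # same seed, same values

# Synchronization: block the host until all kernels queued on the
# MPS stream have finished executing.
torch.mps.synchronize()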

Diff for: aten/src/ATen/detail/MPSHooksInterface.h

+8

@@ -28,13 +28,21 @@ struct TORCH_API MPSHooksInterface {
     return false;
   }
 
+  virtual bool isOnMacOS13orNewer() const {
+    AT_ERROR("MPS backend is not available.");
+  }
+
   virtual const Generator& getDefaultMPSGenerator() const {
     AT_ERROR("Cannot get default MPS generator without MPS backend.");
   }
 
   virtual Allocator* getMPSDeviceAllocator() const {
     AT_ERROR("MPSDeviceAllocator requires MPS.");
   }
+
+  virtual void deviceSynchronize() const {
+    AT_ERROR("Cannot synchronize MPS device without MPS backend.");
+  }
 };
 
 struct TORCH_API MPSHooksArgs {};

Diff for: aten/src/ATen/mps/MPSDevice.h

+1 −1

@@ -79,7 +79,7 @@ class TORCH_API MPSDevice {
 
 TORCH_API bool is_available();
 TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
-
+TORCH_API void device_synchronize();
 TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
 
 } // namespace mps

Diff for: aten/src/ATen/mps/MPSDevice.mm

+5

@@ -3,6 +3,7 @@
 #include <c10/util/CallOnce.h>
 
 #include <ATen/mps/MPSDevice.h>
+#include <ATen/mps/MPSStream.h>
 #include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/mps/IndexKernels.h>
 
@@ -122,5 +123,9 @@ bool is_macos_13_or_newer(MacOSVersion version) {
   return MPSDevice::getInstance()->isMacOS13Plus(version);
 }
 
+void device_synchronize() {
+  getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
+}
+
 } // namespace mps
 } // namespace at

Diff for: aten/src/ATen/mps/MPSHooks.cpp

+8

@@ -16,6 +16,10 @@ bool MPSHooks::hasMPS() const {
   return at::mps::is_available();
 }
 
+bool MPSHooks::isOnMacOS13orNewer() const {
+  return at::mps::is_macos_13_or_newer();
+}
+
 Allocator* MPSHooks::getMPSDeviceAllocator() const {
   return at::mps::GetMPSAllocator();
 }
@@ -24,6 +28,10 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
   return at::mps::detail::getDefaultMPSGenerator();
 }
 
+void MPSHooks::deviceSynchronize() const {
+  at::mps::device_synchronize();
+}
+
 using at::MPSHooksRegistry;
 using at::RegistererMPSHooksRegistry;

Diff for: aten/src/ATen/mps/MPSHooks.h

+2

@@ -13,8 +13,10 @@ struct MPSHooks : public at::MPSHooksInterface {
   MPSHooks(at::MPSHooksArgs) {}
   void initMPS() const override;
   bool hasMPS() const override;
+  bool isOnMacOS13orNewer() const override;
   Allocator* getMPSDeviceAllocator() const override;
   const Generator& getDefaultMPSGenerator() const override;
+  void deviceSynchronize() const override;
 };
 
 }} // at::mps

Diff for: build_variables.bzl

+1

@@ -822,6 +822,7 @@ libtorch_python_core_sources = [
     "torch/csrc/dynamo/guards.cpp",
     "torch/csrc/dynamo/init.cpp",
     "torch/csrc/functorch/init.cpp",
+    "torch/csrc/mps/Module.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/python/init.cpp",
     "torch/csrc/jit/passes/onnx.cpp",

Diff for: docs/source/index.rst

+1

@@ -81,6 +81,7 @@ Features described in this documentation are classified by release status:
    torch.autograd <autograd>
    torch.library <library>
    cuda
+   mps
    torch.backends <backends>
    torch.distributed <distributed>
    torch.distributed.algorithms.join <distributed.algorithms.join>

Diff for: docs/source/mps.rst

+14

@@ -0,0 +1,14 @@
+torch.mps
+===================================
+.. automodule:: torch.mps
+.. currentmodule:: torch.mps
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    synchronize
+    get_rng_state
+    set_rng_state
+    manual_seed
+    seed

Diff for: test/test_mps.py

+39

@@ -5972,6 +5972,45 @@ def test_mps_generator(self):
         mps_x = torch.randn(5, device='mps', generator=g_mps)
         self.assertEqual(mps_x, mps_y)
 
+    def test_default_mps_generator(self):
+        # manual seeding on the "default" MPS generator using
+        # the global torch.manual_seed()
+        torch.manual_seed(230)
+        mps_x = torch.randn(5, device='mps')
+        # manual seeding using torch.mps.manual_seed()
+        # which should set the "default" MPS generator
+        # like the global torch.manual_seed()
+        torch.mps.manual_seed(230)
+        mps_y = torch.randn(5, device='mps')
+        # seed values were the same, so the random tensor contents should match
+        self.assertEqual(mps_x, mps_y)
+
+        # save the default generator's state to restore it later
+        g_state = torch.mps.get_rng_state()
+
+        # generate random numbers without seeding
+        mps_x = torch.randn(5, device='mps')
+        # in this case, the random results must differ from the last generated random results
+        self.assertNotEqual(mps_x, mps_y)
+
+        # restore the previously saved state, and the results should match again
+        torch.mps.set_rng_state(g_state)
+        mps_x = torch.randn(5, device='mps')
+        self.assertEqual(mps_x, mps_y)
+
+    def test_device_synchronize(self):
+        # just running some ops each followed by a synchronize to wait for
+        # MPS stream to finish running each of them
+        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
+            .to(device='mps', dtype=torch.float)
+
+        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+        torch.mps.synchronize()
+        x = net1(x)
+        torch.mps.synchronize()
+        x.backward(torch.randn_like(x))
+        torch.mps.synchronize()
+
     # Test random_.to and random_.from
     def test_random(self):
         def helper(shape, low, high, dtype=torch.int32):

Diff for: torch/_C/__init__.pyi.in

+6 −2

@@ -903,8 +903,6 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T
 def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
 def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
 def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
-def _is_mps_available() -> _bool: ...
-def _is_mps_on_macos_13_or_newer() -> _bool: ...
 class _LinalgBackend:
     Default: _LinalgBackend
     Cusolver: _LinalgBackend
@@ -1200,6 +1198,12 @@ class _TensorBase(metaclass=_TensorMeta):
 # Defined in torch/csrc/multiprocessing/init.cpp
 def _multiprocessing_init() -> None: ...
 
+# Defined in torch/csrc/mps/Module.cpp
+def _mps_synchronize() -> None: ...
+def _mps_get_default_generator() -> Generator: ...
+def _is_mps_available() -> _bool: ...
+def _is_mps_on_macos_13_or_newer() -> _bool: ...
+
 # Defined in torch/csrc/cuda/Module.cpp
 def _cuda_getCurrentStream(device: _int) -> Tuple: ...
 def _cuda_getCurrentRawStream(device: _int) -> _int: ...

Diff for: torch/csrc/Module.cpp

+2 −13

@@ -60,6 +60,7 @@
 #include <torch/csrc/jit/serialization/pickler.h>
 #include <torch/csrc/lazy/python/init.h>
 #include <torch/csrc/monitor/python_init.h>
+#include <torch/csrc/mps/Module.h>
 #include <torch/csrc/multiprocessing/init.h>
 #include <torch/csrc/onnx/init.h>
 #include <torch/csrc/profiler/python/init.h>
@@ -87,10 +88,6 @@
 #endif
 #endif
 
-#if defined(USE_MPS)
-#include <ATen/mps/MPSDevice.h>
-#endif
-
 #if defined(USE_VALGRIND)
 #include <callgrind.h>
 #endif
@@ -1271,6 +1268,7 @@ PyObject* initModule() {
   THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
   THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
   THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
+  THPUtils_addPyMethodDefs(methods, torch::mps::python_functions());
 #ifdef USE_CUDA
   THPUtils_addPyMethodDefs(methods, THCPModule_methods());
 #endif
@@ -1593,15 +1591,6 @@ Call this whenever a new thread is created in order to propagate values from
 
   ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
   ASSERT_TRUE(set_module_attr("has_mps", has_mps));
-  py_module.def("_is_mps_available", []() { return at::hasMPS(); });
-  py_module.def("_is_mps_on_macos_13_or_newer", []() {
-#ifdef USE_MPS
-    return at::mps::is_macos_13_or_newer();
-#else
-    return false;
-#endif
-  });
-
   ASSERT_TRUE(
       set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));

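The two availability queries removed here keep their Python-visible names: they are re-registered through the method table returned by `torch::mps::python_functions()` in the new `torch/csrc/mps/Module.cpp` below. A small sketch of probing them from Python, assuming an MPS-enabled build (the `torch.backends.mps.is_available()` call is the usual public wrapper and is not part of this diff):

import torch

# Public helper, shown only for context.
print(torch.backends.mps.is_available())

# Private bindings now implemented in torch/csrc/mps/Module.cpp,
# with signatures taken from the .pyi stub in this commit.
print(torch._C._is_mps_available())
print(torch._C._is_mps_on_macos_13_or_newer())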
Diff for: torch/csrc/mps/Module.cpp

+102

@@ -0,0 +1,102 @@
+#include <ATen/ATen.h>
+#include <c10/util/CallOnce.h>
+#include <torch/csrc/Generator.h>
+#include <torch/csrc/python_headers.h>
+#include <torch/csrc/utils/python_numbers.h>
+
+// pthread.h is included for tracking bad forks
+#ifndef WIN32
+#include <pthread.h>
+#endif
+
+namespace torch {
+namespace mps {
+
+namespace {
+// True for children forked after mps init
+static bool in_bad_fork = false;
+
+// Called in the forked child if mps has already been initialized
+static void forked_mps_child() {
+  in_bad_fork = true;
+}
+
+// Should be called before the first mps call.
+static void track_bad_mps_fork() {
+#ifndef WIN32
+  static c10::once_flag flag;
+  c10::call_once(
+      flag, [] { pthread_atfork(nullptr, nullptr, forked_mps_child); });
+#endif
+}
+} // namespace
+
+static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  return PyBool_FromLong(in_bad_fork);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_getDefaultMPSGenerator(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  track_bad_mps_fork();
+  return THPGenerator_initDefaultGenerator(
+      at::detail::getMPSHooks().getDefaultMPSGenerator());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  track_bad_mps_fork();
+  if (at::detail::getMPSHooks().hasMPS()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_isMacOS13orNewer(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  at::detail::getMPSHooks().deviceSynchronize();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+// NOLINTNEXTLINE(modernize-avoid-c-arrays,
+// cppcoreguidelines-avoid-non-const-global-variables,
+// cppcoreguidelines-avoid-c-arrays)
+static struct PyMethodDef _MPSModule_methods[] = {
+    {"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
+    {"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
+    {"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
+    {"_is_mps_on_macos_13_or_newer",
+     MPSModule_isMacOS13orNewer,
+     METH_NOARGS,
+     nullptr},
+    {"_mps_get_default_generator",
+     MPSModule_getDefaultMPSGenerator,
+     METH_NOARGS,
+     nullptr},
+    {nullptr}};
+
+PyMethodDef* python_functions() {
+  return _MPSModule_methods;
+}
+
+} // namespace mps
+} // namespace torch

Diff for: torch/csrc/mps/Module.h

+11

@@ -0,0 +1,11 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+
+namespace torch {
+namespace mps {
+
+PyMethodDef* python_functions();
+
+} // namespace mps
+} // namespace torch

Diff for: torch/mps/__init__.py

+54

@@ -0,0 +1,54 @@
+r"""
+This package enables an interface for accessing MPS backend in python
+"""
+import torch
+from .. import Tensor
+
+_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
+_default_mps_generator: torch._C.Generator = None  # type: ignore[assignment]
+
+# local helper function (not public or exported)
+def _get_default_mps_generator() -> torch._C.Generator:
+    global _default_mps_generator
+    if _default_mps_generator is None:
+        _default_mps_generator = torch._C._mps_get_default_generator()
+    return _default_mps_generator
+
+def synchronize() -> None:
+    r"""Waits for all kernels in all streams on a MPS device to complete."""
+    return torch._C._mps_synchronize()
+
+def get_rng_state() -> Tensor:
+    r"""Returns the random number generator state as a ByteTensor."""
+    return _get_default_mps_generator().get_state()
+
+def set_rng_state(new_state: Tensor) -> None:
+    r"""Sets the random number generator state.
+
+    Args:
+        new_state (torch.ByteTensor): The desired state
+    """
+    new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
+    _get_default_mps_generator().set_state(new_state_copy)
+
+def manual_seed(seed: int) -> None:
+    r"""Sets the seed for generating random numbers.
+
+    Args:
+        seed (int): The desired seed.
+    """
+    # the torch.mps.manual_seed() can be called from the global
+    # torch.manual_seed() in torch/random.py. So we need to make
+    # sure mps is available (otherwise we just return without
+    # erroring out)
+    if not torch.has_mps:
+        return
+    seed = int(seed)
+    _get_default_mps_generator().manual_seed(seed)
+
+def seed() -> None:
+    r"""Sets the seed for generating random numbers to a random number."""
+    _get_default_mps_generator().seed()
+
+__all__ = [
+    'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize']
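Beyond seeding, the module exposes the generator state directly. A short sketch of the save/restore pattern these functions enable, mirroring the new test case and assuming MPS is available:

import torch

state = torch.mps.get_rng_state()     # snapshot the default MPS generator
a = torch.randn(3, device="mps")

torch.mps.set_rng_state(state)        # rewind to the snapshot
b = torch.randn(3, device="mps")
assert torch.equal(a.cpu(), b.cpu())  # replay produces identical values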
