Skip to content

Commit

Permalink
[MPS] Add optional minor argument to is_macos13_or_newer (pytorch…
Browse files Browse the repository at this point in the history
…#95065)

Will be needed if one wants to make accurate XFAIL validation

I.e. `torch.backends.mps.is_macos13_or_newer()` will return True if PyTorch is running on MacOS 13.0 or newer, `torch.backends.mps.is_macos13_or_newer(1)` will return True if running on MacOS 13.1 or newer and `torch.backends.mps.is_macos13_or_newer(2)` will return True  if running on MacOS 13.2 or newer

Do not use 13.3 check as `@available` does not really work for shared libraries

Pull Request resolved: pytorch#95065
Approved by: https://github.com/albanD
  • Loading branch information
malfet authored and jhavukainen committed Mar 15, 2024
1 parent cc9c3f4 commit 74d779f
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 12 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/detail/MPSHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct TORCH_API MPSHooksInterface {
return false;
}

virtual bool isOnMacOS13orNewer() const {
virtual bool isOnMacOS13orNewer(unsigned minor = 0) const {
AT_ERROR("MPS backend is not available.");
}

Expand Down
14 changes: 12 additions & 2 deletions aten/src/ATen/mps/MPSHooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,18 @@ bool MPSHooks::hasMPS() const {
return at::mps::is_available();
}

bool MPSHooks::isOnMacOS13orNewer() const {
return at::mps::is_macos_13_or_newer();
bool MPSHooks::isOnMacOS13orNewer(unsigned minor) const {
switch (minor) {
case 0:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS);
case 1:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
case 2:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
default:
TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.2+");
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
}
}

Allocator* MPSHooks::getMPSDeviceAllocator() const {
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSHooks.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct MPSHooks : public at::MPSHooksInterface {
MPSHooks(at::MPSHooksArgs) {}
void initMPS() const override;
bool hasMPS() const override;
bool isOnMacOS13orNewer() const override;
bool isOnMacOS13orNewer(unsigned minor) const override;
Allocator* getMPSDeviceAllocator() const override;
const Generator& getDefaultMPSGenerator() const override;
void deviceSynchronize() const override;
Expand Down
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,7 @@ def _mps_setMemoryFraction(fraction: _float) -> None: ...
def _mps_currentAllocatedMemory() -> _int: ...
def _mps_driverAllocatedMemory() -> _int: ...
def _mps_is_available() -> _bool: ...
def _mps_is_on_macos_13_or_newer() -> _bool: ...
def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ...

# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
Expand Down
4 changes: 2 additions & 2 deletions torch/backends/mps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def is_available() -> bool:


@_lru_cache()
def is_macos13_or_newer() -> bool:
def is_macos13_or_newer(minor: int = 0) -> bool:
r"""Returns a bool indicating whether MPS is running on MacOS 13 or newer."""
return torch._C._mps_is_on_macos_13_or_newer()
return torch._C._mps_is_on_macos_13_or_newer(minor)


# Register prims as implementation of var_mean and group_norm
Expand Down
11 changes: 6 additions & 5 deletions torch/csrc/mps/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
END_HANDLE_TH_ERRORS
}

static PyObject* MPSModule_isMacOS13orNewer(
PyObject* _unused,
PyObject* noargs) {
static PyObject* MPSModule_isMacOS13orNewer(PyObject* _unused, PyObject* args) {
HANDLE_TH_ERRORS
if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
THPUtils_assert(
THPUtils_checkLong(args), "invalid argument to isOnMacOS13orNewer()");
auto minor = THPUtils_unpackUInt32(args);
if (at::detail::getMPSHooks().isOnMacOS13orNewer(minor)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
Expand Down Expand Up @@ -124,7 +125,7 @@ static struct PyMethodDef _MPSModule_methods[] = {
{"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
{"_mps_is_on_macos_13_or_newer",
MPSModule_isMacOS13orNewer,
METH_NOARGS,
METH_O,
nullptr},
{"_mps_get_default_generator",
MPSModule_getDefaultMPSGenerator,
Expand Down

0 comments on commit 74d779f

Please sign in to comment.