From ae3f68a17e4417fce00e3a610a6f5ae4a72c9cca Mon Sep 17 00:00:00 2001 From: torzdf <36920800+torzdf@users.noreply.github.com> Date: Wed, 10 Jul 2024 18:13:07 +0100 Subject: [PATCH] Apple-Silicon: Place unsupported Ops on to CPU --- lib/cli/launcher.py | 6 +++++- setup.cfg | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/cli/launcher.py b/lib/cli/launcher.py index ec6cd185e6..f3e3f19af0 100644 --- a/lib/cli/launcher.py +++ b/lib/cli/launcher.py @@ -11,7 +11,7 @@ from lib.gpu_stats import GPUStats from lib.logger import crash_log, log_setup -from lib.utils import FaceswapError, get_torch_version, safe_shutdown, set_backend +from lib.utils import FaceswapError, get_backend, get_torch_version, safe_shutdown, set_backend if T.TYPE_CHECKING: import argparse @@ -47,6 +47,10 @@ def _set_environment_variables(self) -> None: logger.debug("Setting NUMEXPR_MAX_THREADS to %s", allocate) os.environ["NUMEXPR_MAX_THREADS"] = str(allocate) + if get_backend() == "apple_silicon": # Let apple put unsupported ops on the CPU + logger.debug("Enabling unsupported Ops on CPU for Apple Silicon") + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + def _import_script(self) -> Callable: """ Imports the relevant script as indicated by :attr:`_command` from the scripts folder. diff --git a/setup.cfg b/setup.cfg index 52a44811d2..3572057f42 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-tensorboard.*] ignore_missing_imports = True -[mypy-tensorflow.*] +[mypy-torch.*] ignore_missing_imports = True [mypy-tqdm.*] ignore_missing_imports = True